Class incremental learning of the GTSRB dataset¶


The goal of the project is to develop a solution for recognizing images of road signs. This is a classical dataset, and it is available in many repositories, e.g. torchvision.

The images are small (32x32) and the dataset itself is modest in size, so a deep network can run on a (good) CPU in a reasonable time. A GPU will of course make it faster.

The question addressed in this project is how to introduce new categories to classify sequentially: this setting is called class-incremental learning in the literature.

The difficulty of this learning setting is "catastrophic forgetting": when new categories (new road signs) to recognize are sequentially added to the recognition system, the old categories are forgotten if nothing is done.

Several strategies can be used to avoid forgetting:

  • Rehearsal (maintain a small memory buffer of previous examples)
  • Regularization (Knowledge distillation, EWC...)
  • Incremental architecture
  • ...
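To make the rehearsal idea concrete, here is a minimal sketch of a fixed-size memory buffer filled by reservoir sampling, so that every example seen so far has the same probability of being kept. The class name `RehearsalBuffer` and its interface are our own illustration, not a library API:

```python
import random

class RehearsalBuffer:
    """Fixed-size memory of past (x, y) pairs, filled by reservoir sampling."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.data = []       # stored (x, y) pairs
        self.n_seen = 0      # total number of examples offered to the buffer

    def add(self, x, y):
        self.n_seen += 1
        if len(self.data) < self.capacity:
            self.data.append((x, y))
        else:
            # Replace a random slot with probability capacity / n_seen,
            # which keeps a uniform sample over everything seen so far
            j = random.randrange(self.n_seen)
            if j < self.capacity:
                self.data[j] = (x, y)

    def sample(self, k):
        """Draw up to k stored examples without replacement."""
        return random.sample(self.data, min(k, len(self.data)))
```

A training loop would call `add` on each incoming example and mix `sample(k)` into every batch of new-task data.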

The figure below shows two learning strategies when new classes are sequentially added, eight at a time.

  • The "upper bound" curve is generated using all available data (non-incremental) and is the expected target performance.

  • The "Fine tuning" curve is obtained by simply applying gradient descent on the new data: it is the simplest strategy and it produces catastrophic forgetting.

  • The "rehearsal KD" curve is obtained using a rehearsal memory buffer of size 200 and Knowledge Distillation between old and new classes.

[Figure: classification accuracy as a function of the number of learned classes, for the "upper bound", "Fine tuning", and "rehearsal KD" strategies]
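The distillation term used by a "rehearsal KD" strategy can be sketched as follows: the current model's logits on the old classes are pushed toward the logits of a frozen copy of the previous model. This is a generic distillation loss in the spirit of Hinton et al.; the function name and the temperature `T` are our own illustrative choices:

```python
import torch
import torch.nn.functional as F

def distillation_loss(new_logits, old_logits, T=2.0):
    """KL-style distillation on the old-class logits.

    new_logits: current model's outputs restricted to the old classes
    old_logits: outputs of the frozen previous model (soft targets)
    T: temperature softening both distributions
    """
    log_p_new = F.log_softmax(new_logits / T, dim=1)
    p_old = F.softmax(old_logits / T, dim=1)
    # The usual T^2 rescaling keeps gradient magnitudes comparable across T
    return F.kl_div(log_p_new, p_old, reduction="batchmean") * (T * T)
```

In training, this term is added (with some weight) to the cross-entropy loss on the new classes.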

A few references to help you find a solution:

  • Tutorial: https://sites.google.com/view/neurips2022-llm-tutorial

  • Recent survey: Zhou, D. W., Wang, Q. W., Qi, Z. H., Ye, H. J., Zhan, D. C., & Liu, Z. (2024). Class-incremental learning: A survey. IEEE Transactions on Pattern Analysis and Machine Intelligence: https://github.com/zhoudw-zdw/CIL_Survey/

  • Huge list of papers: https://github.com/ContinualAI/continual-learning-papers

  • Van de Ven, G. M., & Tolias, A. S. (2019). Three scenarios for continual learning: https://arxiv.org/abs/1904.07734, https://github.com/GMvandeVen/continual-learning

  • A simple strategy: Prabhu, A., Torr, P. H., & Dokania, P. K. (2020). Gdumb: A simple approach that questions our progress in continual learning. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part II 16 (pp. 524-540). https://github.com/drimpossible/GDumb

A Jupyter notebook to help you start (data loaders, baseline for class-incremental learning): Open In Colab


Class incremental learning on the GTSRB dataset¶


This notebook contains several code snippets to help with your project:

  • data loaders
  • A baseline for incremental learning using fine-tuning
  • Examples of how to use Weights & Biases for logging your results.


$\leadsto$   Remark 1

  • In Section 9: Rehearsal Memory Buffer + Knowledge Distillation Loss (KD), we implement class-incremental learning using a rehearsal memory buffer of fixed total size 200 and apply Knowledge Distillation (KD) to transfer knowledge from old to new classes. [1] [2]

  • In Section 10: Dynamically Expandable Representation (DER), we implement class-incremental learning using a rehearsal memory buffer where each class stores a fixed number of 10 samples. Additionally, we apply Knowledge Distillation between old and new classes, combined with a Dynamically Expandable Representation (DER_CNN) model, which dynamically expands the classifier to accommodate new classes while preserving old knowledge. [3] [4]
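As a minimal illustration of the per-class buffer mentioned for Section 10 (a fixed number of samples stored per class), the sketch below simply keeps the first `per_class` indices of each class. The function name is ours, and real iCaRL/DER implementations typically use herding rather than first-come selection:

```python
def per_class_buffer_indices(targets, classes, per_class=10):
    """Return dataset indices keeping at most `per_class` samples per class.

    targets: list of integer labels for the whole dataset
    classes: iterable of class ids to store in the buffer
    per_class: fixed number of exemplars kept for each class
    """
    counts = {c: 0 for c in classes}
    buffer = []
    for i, y in enumerate(targets):
        if y in counts and counts[y] < per_class:
            buffer.append(i)
            counts[y] += 1
    return buffer
```

The resulting index list can be wrapped in a `torch.utils.data.Subset` and concatenated with the new task's data.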

References

[1] Zhou, D. W., Wang, Q. W., Qi, Z. H., Ye, H. J., Zhan, D. C., & Liu, Z. (2024). Class-Incremental Learning: A Survey. IEEE Transactions on Pattern Analysis and Machine Intelligence. PDF

[2] Zhou et al. (2024). iCaRL Implementation. GitHub Repository. iCaRL.py

[3] Yan, S., Xie, J., Wang, C., Gong, Y., & Yan, J. (2021). DER: Dynamically Expandable Representation for Class Incremental Learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). PDF

[4] Yan et al. (2021). DER Implementation. GitHub Repository. main.py

Section 0: Importing Common Libraries and Modules¶

0.1 Importing Essential Libraries¶

In [3]:
####################################
'''A. PyTorch modules'''
import torch
import torch.nn as nn                     # Neural network modules
import torch.nn.init as init              # Network parameter initialization
import torch.optim as optim               # Optimizers
import torch.nn.functional as F           # Common functions (e.g. activation functions)

'''B. Torchvision modules'''
from torch.utils.data import DataLoader, ConcatDataset, Subset  # Data loading and handling
from torchvision.utils import make_grid   # For visualizing images
from torchvision import transforms, datasets  # Image datasets and data augmentation
import torchvision.models as models       # Pre-trained models
from torchvision.transforms import v2     # Newer data augmentation API in torchvision

'''C. Helper libraries'''
import copy                               # For deep copying models or data

import numpy as np                        # Scientific computing library
import random                             # Random number generation
import time, os                           # Timing and OS interaction
import matplotlib.pyplot as plt           # Plotting library
from PIL import Image                     # Image processing
from tqdm import tqdm                     # Progress bar
import pandas as pd                       # Dataframe handling
import math                               # Math functions

0.2 Device Configuration¶

In [4]:
# Device configuration: use the GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
In [5]:
# # Useful if you want to store intermediate results on your drive
# from google.colab import drive


# drive.mount('/content/gdrive/')
# DATA_DIR =  '/content/gdrive/MyDrive/teaching/ENSTA/2024'
In [6]:
# Check if GPU is available
if torch.cuda.is_available():
  !nvidia-smi
Sat Mar 22 05:08:51 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.29.01              Driver Version: 546.01       CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA T1200 Laptop GPU        On  | 00000000:01:00.0 Off |                  N/A |
| N/A   52C    P0              17W /  90W |      0MiB /  4096MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+


Section 1: Data loaders¶

In [7]:
'''A. Define transformations'''

# Training data transformations
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(32),           # Random resized crop to 32x32
    transforms.RandomHorizontalFlip(),          # Random horizontal flip
    transforms.ToTensor(),                      # Convert to tensor
    # transforms.Normalize((0.3403, 0.3121, 0.3214), (0.2724, 0.2608, 0.2669)) # GTSRB dataset mean and std
])

# Testing data transformations
transform_test = transforms.Compose([
    transforms.Resize(32),                      # Resize to 32x32
    transforms.ToTensor(),                      # Convert to tensor
    # transforms.Normalize((0.3403, 0.3121, 0.3214), (0.2724, 0.2608, 0.2669))
])

'''B. Using the new torchvision.transforms v2 API (overrides the definitions above)'''

# Training data transformations (v2)
transform_train = v2.Compose([
    # v2.Grayscale(),                           # Optional grayscale conversion
    # v2.RandomResizedCrop(32),
    v2.Resize((32, 32)),                        # Resize to 32x32
    v2.ToImage(),                               # Convert to image format
    v2.ToDtype(torch.float32, scale=True),      # Convert to float32 and scale to [0,1]
    # v2.Normalize((0.3403, 0.3121, 0.3214), (0.2724, 0.2608, 0.2669)) 
])

# Testing data transformations (v2)
transform_test = v2.Compose([
    # v2.Grayscale(),
    v2.Resize((32, 32)),                        # Resize to 32x32
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    # v2.Normalize((0.3403, 0.3121, 0.3214), (0.2724, 0.2608, 0.2669))
])

'''C. Define dataset and DataLoader helpers'''

# Get dataset
def get_dataset(root_dir, transform, train=True):
    """
    root_dir: path where the dataset is stored
    transform: transformations applied
    train: whether to load the training set
    """
    dataset = datasets.GTSRB(root=root_dir, split='train' if train else 'test', download=True, transform=transform)
    target = [data[1] for data in dataset]  # Extract labels
    return dataset, target

# Create DataLoader
def create_dataloader(dataset, targets, current_classes, batch_size, shuffle):
    """
    dataset: full dataset
    targets: list of labels
    current_classes: classes to be used for the current task
    batch_size: batch size
    shuffle: whether to shuffle the data
    """
    indices = [i for i, label in enumerate(targets) if label in current_classes]  # Keep only samples of the current classes
    subset = Subset(dataset, indices)                                             # Create subset
    dataloader = DataLoader(subset, batch_size=batch_size, shuffle=shuffle)       # Create DataLoader
    return dataloader
In [8]:
'''D. Specify dataset directory & load the full training and test sets'''
# Specify dataset directory
root_dir = './data'

# Load training and test datasets
train_dataset = datasets.GTSRB(root=root_dir, split='train', download=True, transform=transform_train)
test_dataset = datasets.GTSRB(root=root_dir, split='test', download=True, transform=transform_test)

print(f"Train dataset contains {len(train_dataset)} images")
print(f"Test dataset contains {len(test_dataset)} images")


'''E. Load target labels and class names'''
# Load local target ID lists and class names

import csv

# Load test target IDs
data = pd.read_csv('./test_target.csv', delimiter=',', header=None)
test_target = data.to_numpy().squeeze().tolist()

# Load train target IDs
data = pd.read_csv('./train_target.csv', delimiter=',', header=None)
train_target = data.to_numpy().squeeze().tolist()

# Load class names
data = pd.read_csv('./signnames.csv')
class_names = data['SignName'].tolist()
Train dataset contains 26640 images
Test dataset contains 12630 images
In [9]:
'''
# Loads datasets (on your local computer)
root_dir = '/home/stephane/Documents/Onera/Cours/ENSTA/2025/data'

# Loads datasets (on Colab local computer)
root_dir = './data'

train_dataset = datasets.GTSRB(root=root_dir, split='train', download=True, transform=transform_train)
test_dataset = datasets.GTSRB(root=root_dir, split='test', download=True, transform=transform_test)

print(f"Train dataset contains {len(train_dataset)} images")
print(f"Test dataset contains {len(test_dataset)} images")

# Loads target id lists and class names (not in torchvision dataset)
import csv
data = pd.read_csv('https://raw.githubusercontent.com/stepherbin/teaching/refs/heads/master/IOGS/projet/test_target.csv', delimiter=',', header=None)
test_target = data.to_numpy().squeeze().tolist()

data = pd.read_csv('https://raw.githubusercontent.com/stepherbin/teaching/refs/heads/master/IOGS/projet/train_target.csv', delimiter=',', header=None)
train_target = data.to_numpy().squeeze().tolist()

data = pd.read_csv('https://raw.githubusercontent.com/stepherbin/teaching/refs/heads/master/IOGS/projet/signnames.csv')
class_names = data['SignName'].tolist()
'''


Section 2: Display of images¶

2.1 Display of images¶

In [10]:
# Get the number of unique classes in the training set
nclasses = len(np.unique(train_target))

# Create a list of all class indices
all_classes = list(range(nclasses))

# You can shuffle the class order if needed (currently commented out)
# random.shuffle(all_classes)

# Number of classes per task
classes_per_task = 8

# List of classes for the current task
current_classes = []

# Task index, starting from 0
task = 0

# Select classes for the current task
'''
Each task takes the next 8 consecutive classes (the first 8, then the next 8, and so on),
so that different groups of classes are trained batch by batch:
task 0 : task_classes   [0,1,2,3,4,5,6,7]
task 1 : task_classes   [8,9,10,11,12,13,14,15]
task 2 : task_classes   [16,17,18,19,20,21,22,23]
...
'''
task_classes = all_classes[task * classes_per_task : (task + 1) * classes_per_task]

# Add selected classes to current_classes
current_classes.extend(task_classes)

# Batch size
batch_size = 64

# ==========================
# Create DataLoaders for training and testing for the current task
train_loader = create_dataloader(train_dataset, train_target, current_classes, batch_size, shuffle=True)
test_loader = create_dataloader(test_dataset, test_target, current_classes, batch_size, shuffle=True)

# ==========================
# Function to display image samples
def show(img):
    npimg = img.numpy()
    # Transpose dimensions and display the image
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')

# Get one batch from the train loader
sample, targets = next(iter(train_loader))

# Display the image grid
show(make_grid(sample))
plt.show()

# ==========================
# Print the shape of the batch
print(sample.shape)

# 64 is the batch size
# The second dimension is the number of channels: 1 for grayscale, 3 for RGB
# The last two dimensions are image height and width (32x32 here)
torch.Size([64, 3, 32, 32])

2.2 Advanced Batch Visualization¶

In [11]:
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
In [12]:
def show_batch(img, targets):
    """
    Display a batch of images and their labels in an 8x8 grid
    """
    # Build an 8x8 image grid
    grid_img = make_grid(img, nrow=8, padding=2)

    plt.figure(figsize=(6, 6))
    npimg = grid_img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')

    # Annotate each tile with its label
    for idx in range(len(targets)):
        row = idx // 8
        col = idx % 8
        plt.text(
            x=col * (32 + 2) + 4,  # X coordinate, accounting for padding
            y=row * (32 + 2) + 30, # Y coordinate
            s=str(targets[idx].item()),  # Label value
            color='white',
            fontsize=10,
            bbox=dict(facecolor='black', alpha=0.7, boxstyle='round,pad=0.2')
        )
    
    plt.show()
In [13]:
show_batch(sample, targets)


In [14]:
'''
nclasses = len(np.unique(train_target))
all_classes = list(range(nclasses))
#random.shuffle(all_classes)
classes_per_task = 8
current_classes = []

task = 0
task_classes = all_classes[task * classes_per_task : (task + 1) * classes_per_task]
current_classes.extend(task_classes)
batch_size = 64

# Create data for first task
train_loader = create_dataloader(train_dataset, train_target, current_classes, batch_size, shuffle = True)
test_loader = create_dataloader(train_dataset, train_target, current_classes, batch_size, shuffle = True)

# Displays a few examples
def show(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')

sample,targets = next(iter(train_loader))
show(make_grid(sample))
plt.show()

print(sample.shape)     ## 64 is the batch
                        ## 1 for grey values --  3 for RGB
                        ## 32x32 for image size (small here)
                        
'''

2.3 Test Set Label Distribution Analysis¶

In [15]:
# Create a DataLoader for the test set, including all classes
test_loader = create_dataloader(test_dataset, test_target, all_classes, batch_size, shuffle=True)

# ==========================
# (Optional) Extract test labels from the DataLoader - commented out here

# gtsrbtest_gt = []  # To store all targets
# for _, targets in test_loader:
#     gtsrbtest_gt += targets.numpy().tolist()  # Convert targets to a list and append
# print(len(gtsrbtest_gt))  # Print total number of targets

# ==========================
# Directly count the label distribution in test_target

from collections import Counter

# Count occurrences of each label, sorted by most common
label_counts = Counter(test_target).most_common()

# Print statistics for each class
print("Count\tClassID\tClassName")
for l, c in label_counts:
    print(f"{c}\t{l}\t{class_names[l]}")  # Output format: count, class ID, class name
Count	ClassID	ClassName
750	2	Speed limit (50km/h)
720	1	Speed limit (30km/h)
720	13	Yield
690	38	Keep right
690	12	Priority road
660	4	Speed limit (70km/h)
660	10	No passing for vechiles over 3.5 metric tons
630	5	Speed limit (80km/h)
480	25	Road work
480	9	No passing
450	7	Speed limit (100km/h)
450	3	Speed limit (60km/h)
450	8	Speed limit (120km/h)
420	11	Right-of-way at the next intersection
390	18	General caution
390	35	Ahead only
360	17	No entry
270	14	Stop
270	31	Wild animals crossing
210	33	Turn right ahead
210	15	No vechiles
180	26	Traffic signals
150	16	Vechiles over 3.5 metric tons prohibited
150	23	Slippery road
150	30	Beware of ice/snow
150	28	Children crossing
150	6	End of speed limit (80km/h)
120	34	Turn left ahead
120	22	Bumpy road
120	36	Go straight or right
90	21	Double curve
90	20	Dangerous curve to the right
90	24	Road narrows on the right
90	29	Bicycles crossing
90	40	Roundabout mandatory
90	39	Keep left
90	42	End of no passing by vechiles over 3.5 metric tons
60	27	Pedestrians
60	32	End of all speed and passing limits
60	41	End of no passing
60	19	Dangerous curve to the left
60	0	Speed limit (20km/h)
60	37	Go straight or left
In [16]:
'''
test_loader = create_dataloader(train_dataset, train_target, all_classes, batch_size, shuffle = True)

# Get the data from the test set and computes statistics
# gtsrbtest_gt = []
# for _, targets in test_loader:
#   gtsrbtest_gt += targets.numpy().tolist()
# print(len(gtsrbtest_gt))

from collections import Counter

label_counts = Counter(test_target).most_common()
for l, c in label_counts:
    print(c, '\t', l, '\t', class_names[l])
'''

Section 3: Simple networks¶

3.1 Definition of Simple CNN Models¶

In [17]:
# Define a simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self,n_out=10, n_in=1):
        super().__init__()

        # Put the layers here
        self.conv1 = nn.Conv2d(n_in, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        self.fc = nn.Linear(4096, n_out)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x)) ## the n_in x 32x32 image becomes 32x32x32
        x = F.max_pool2d(x, kernel_size=2, stride=2) ## then 32x16x16
        x = F.leaky_relu(self.conv2(x)) ## then 64x16x16
        x = F.max_pool2d(x, kernel_size=2, stride=2) ## then 64x8x8
        x = F.leaky_relu(self.conv3(x)) ## shape unchanged

        x = x.view(-1,4096) ## 64x8x8 is flattened to 4096

        x = self.fc(x) ## final fully connected classification layer

        return x

# Another simple model (compare them using torchinfo below)
class SimpleCNN2(nn.Module):
    def __init__(self, n_out=10, n_in=1):
        super(SimpleCNN2, self).__init__()
        self.conv1 = nn.Conv2d(n_in, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc = nn.Linear(128, n_out)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)  # Flatten the tensor
        x = self.relu(self.fc1(x))
        x = self.fc(x)
        return x

3.2 Model Summary¶

In [18]:
!pip install torchinfo
Requirement already satisfied: torchinfo in /home/home/miniconda3/envs/tensor/lib/python3.11/site-packages (1.8.0)
In [19]:
from torchinfo import summary
model = SimpleCNN(n_out=10, n_in=3)
model.to(device)
print(summary(model, input_size=(batch_size, 3, 32, 32)))

model = SimpleCNN2(n_out=10, n_in=3)
model.to(device)
print(summary(model, input_size=(batch_size, 3, 32, 32)))

#print(model)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
SimpleCNN                                [64, 10]                  --
├─Conv2d: 1-1                            [64, 32, 32, 32]          2,432
├─Conv2d: 1-2                            [64, 64, 16, 16]          18,496
├─Conv2d: 1-3                            [64, 64, 8, 8]            36,928
├─Linear: 1-4                            [64, 10]                  40,970
==========================================================================================
Total params: 98,826
Trainable params: 98,826
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 616.30
==========================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 27.27
Params size (MB): 0.40
Estimated Total Size (MB): 28.45
==========================================================================================
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
SimpleCNN2                               [64, 10]                  --
├─Conv2d: 1-1                            [64, 32, 32, 32]          896
├─ReLU: 1-2                              [64, 32, 32, 32]          --
├─MaxPool2d: 1-3                         [64, 32, 16, 16]          --
├─Conv2d: 1-4                            [64, 64, 16, 16]          18,496
├─ReLU: 1-5                              [64, 64, 16, 16]          --
├─MaxPool2d: 1-6                         [64, 64, 8, 8]            --
├─Linear: 1-7                            [64, 128]                 524,416
├─ReLU: 1-8                              [64, 128]                 --
├─Linear: 1-9                            [64, 10]                  1,290
==========================================================================================
Total params: 545,098
Trainable params: 545,098
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 395.40
==========================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 25.24
Params size (MB): 2.18
Estimated Total Size (MB): 28.20
==========================================================================================

Section 4: Baseline for incremental learning¶

In [20]:
from torch.optim import lr_scheduler
import torch.nn.init as init

# Evaluation
def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader, ncols=80):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

# Simple Training loop
def train(model, train_loader, optimizer, criterion, device, epoch):
    model.train()

    for images, labels in tqdm(train_loader, ncols=80,  desc="Epoch {}".format(epoch)):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

def initialize_weights(module):
    """Initializes the weights of a PyTorch module using Xavier/Glorot initialization."""
    if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):  # Check for relevant layers
        init.xavier_uniform_(module.weight) #Xavier uniform initialization
        if module.bias is not None:
            init.zeros_(module.bias)  # Initialize bias to zero
    elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)): #Initialize normalization layers
        if module.weight is not None:
            init.ones_(module.weight)
        if module.bias is not None:
            init.zeros_(module.bias)


# Main training loop for incremental learning
def incremental_learning(model, train_dataset, train_target, test_dataset, test_target,
                         num_tasks, classes_per_task, batch_size, num_epochs, lr, device):
    nclasses = len(np.unique(train_target))
    all_classes = list(range(nclasses))
    criterion = nn.CrossEntropyLoss()
    current_classes = []
    accuracies = []

    for task in range(num_tasks):
        task_classes = all_classes[task * classes_per_task : (task + 1) * classes_per_task]
        current_classes.extend(task_classes)

        train_loader = create_dataloader(train_dataset, train_target, task_classes, batch_size, shuffle = True)
        test_loader = create_dataloader(test_dataset, test_target, current_classes, batch_size, shuffle = False)

        if task == 0:
            model.fc = nn.Linear(model.fc.in_features, len(current_classes)).to(device)
        else:
            # Expand the output layer for new classes
            old_weight = model.fc.weight.data
            old_bias = model.fc.bias.data
            model.fc = nn.Linear(model.fc.in_features, len(current_classes)).to(device)
            model.fc.weight.data[:len(old_weight)] = old_weight
            model.fc.bias.data[:len(old_bias)] = old_bias

        # Adam is used here; the SGD line is kept as an alternative
        # optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

        print(f"Starting Task {task+1} - Training on classes: {task_classes}")
        for epoch in range(num_epochs): # Adjust number of epochs as needed
            train(model, train_loader, optimizer, criterion, device, epoch)
            scheduler.step()
            accuracy = evaluate(model, train_loader, device)
            print(f"Task {task+1}, Epoch {epoch+1}: Accuracy Train = {accuracy:.2f}%")
        accuracy = evaluate(model, test_loader, device)
        accuracies.append(accuracy)
        print(f"Task {task+1}: Accuracy Test = {accuracy:.2f}%")

    return accuracies
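The head-expansion step inside `incremental_learning` (copying the old `fc` weights into a wider layer so that predictions on old classes start unchanged) can be isolated as a small helper. This is a sketch of the same mechanism; the name `expand_linear_head` is our own:

```python
import torch
import torch.nn as nn

def expand_linear_head(old_fc, n_total_classes):
    """Return a wider nn.Linear whose first rows copy the old head,
    so outputs on previously learned classes are initially unchanged."""
    new_fc = nn.Linear(old_fc.in_features, n_total_classes)
    with torch.no_grad():
        new_fc.weight[: old_fc.out_features] = old_fc.weight
        new_fc.bias[: old_fc.out_features] = old_fc.bias
    return new_fc
```

On the old-class columns the expanded head reproduces the previous head's outputs exactly; only the new rows start from random initialization.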

Section 5: Weights & Biases¶

You can use this environment to log your training runs.

The code below provides a version of the class-incremental function that stores learning curves and the sequence of accuracies for each increment of classes.

To use it, create an account at: https://wandb.ai/

In [21]:
###################################
##### For using Weights & Biases
###############

!pip install wandb -qU

import wandb

wandb.login()
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: kaiyuan_xu (kaiyuan_xu_upsud) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
Out[21]:
True
In [22]:
import math
# Simple Training loop
def train_wandb(model, train_loader, optimizer, criterion, device, epoch):

    step_ct = 0
    n_steps_per_epoch = math.ceil(len(train_loader.dataset) / train_loader.batch_size)

    model.train()

    for step, (images, labels) in tqdm(enumerate(train_loader), ncols=80,  desc="Epoch {}".format(epoch)):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        metrics = {"train/train_loss": loss.item()}
        # metrics = {"train/train_loss": loss,
        #             "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch}

        if step + 1 < n_steps_per_epoch:
          # Log train metrics to wandb
          wandb.log(metrics)
        step_ct += 1


# Main training loop for incremental learning
def incremental_learning_wandb(model, train_dataset, train_target, test_dataset, test_target,
                         num_tasks, classes_per_task, batch_size, num_epochs, lr, device, non_incremental = False):
    nclasses = len(np.unique(train_target))
    all_classes = list(range(nclasses))
    criterion = nn.CrossEntropyLoss()
    current_classes = []
    accuracies = []

    # Copy your config
    config = wandb.config

    wandb.define_metric("task")
    wandb.define_metric("incremental_accuracy", step_metric="task")

    for task in range(num_tasks):
        if non_incremental: # Learn from all available data for all classes seen so far
          task_classes = all_classes[0 : (task + 1) * classes_per_task]
          current_classes = task_classes
          model.apply(initialize_weights)
        else:
          task_classes = all_classes[task * classes_per_task : (task + 1) * classes_per_task]
          current_classes.extend(task_classes)

        train_loader = create_dataloader(train_dataset, train_target, task_classes, batch_size, shuffle = True)
        test_loader = create_dataloader(test_dataset, test_target, current_classes, batch_size, shuffle = False)

        if task == 0 or non_incremental:
            model.fc = nn.Linear(model.fc.in_features, len(current_classes)).to(device)
        else:
            # Expand the output layer for new classes
            old_weight = model.fc.weight.data
            old_bias = model.fc.bias.data
            model.fc = nn.Linear(model.fc.in_features, len(current_classes)).to(device)
            model.fc.weight.data[:len(old_weight)] = old_weight
            model.fc.bias.data[:len(old_bias)] = old_bias

        # Adam is used here; an SGD alternative is kept commented out
        # optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

        print(f"Starting Task {task+1} - Training on classes: {task_classes}")
        for epoch in range(num_epochs): # Adjust number of epochs
            #train(model, train_loader, optimizer, criterion, device, epoch)

            # Train while logging per-step metrics to wandb
            train_wandb(model, train_loader, optimizer, criterion, device, epoch)

            scheduler.step()
            accuracy = evaluate(model, train_loader, device)
            print(f"Task {task+1}, Epoch {epoch+1}: Accuracy Train = {accuracy:.2f}%")

            val_metrics = {"val/val_accuracy": accuracy}
            #wandb.log({**val_metrics})

        accuracy = evaluate(model, test_loader, device)
        accuracies.append(accuracy)
        print(f"Task {task+1}: Accuracy Test = {accuracy:.2f}%")

        incremental_metrics = {"incremental_accuracy": accuracy, "task": task}
        wandb.log({**incremental_metrics})

        # Log train and validation metrics to wandb

    return accuracies

Section 6: Pre-Training¶

In [23]:
# Hyperparameters
root_dir = './data'  # Path to GTSRB dataset
num_tasks = 5
numclasses = len(np.unique(train_target))
classes_per_task = numclasses // num_tasks # 43 // 5 = 8
batch_size = 64
lr = 1e-3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

buffer_size = 200 # Adjust rehearsal set size
alignment_strength = 0.1 # Adjust alignment strength
num_epochs = 4

#model = SimpleCNN(n_out = 1, n_in = 3).to(device)
#model.apply(initialize_weights)

# The name of the network (choose the one you want)
tag = "simpleCNN_GTSRB_pretrained"
netname = os.path.join(root_dir, 'network_{:s}.pth'.format(tag))

#################################################
## Pre-training
####

# Read the last learned network (if stored)
if (os.path.exists(netname)):
    print('Load pre-trained network')
    model = SimpleCNN(n_in = 3, n_out=classes_per_task)
    model.load_state_dict(torch.load(netname,weights_only=True))

    #model = torch.load(netname, weights_only=True)
    model = model.to(device)
else:
    print('Pretrain')
    model = SimpleCNN(n_in = 3, n_out=1)
    model.apply(initialize_weights)
    model.to(device)

    accu = incremental_learning(model, train_dataset, train_target, test_dataset, test_target,
                        1, classes_per_task, batch_size, num_epochs, lr, device)

    print(f"!!!!! Pre-training on first task  = {accu[0]:.2f}%")

    # Save last learned model
    #torch.save(model, netname)
    torch.save(model.state_dict(), netname)

## Copy the model so every strategy starts from the same initialization
copy_model = copy.deepcopy(model)

#### Learn with a single epoch in incremental (faster but less accurate)
num_epochs = 1
Load pre-trained network

Section 7: Fine tuning¶

In [24]:
#############################################
## Fine tuning
####
# initialise a wandb run
num_epochs = 5

run = wandb.init(
    project="GTSRB-CIL",
    name = "Fine tuning",
    config={
        "epochs": num_epochs,
        "batch_size": batch_size,
        "num_tasks": num_tasks,
        "classes_per_task": classes_per_task,
        "lr": lr,
        })

# Simple Incremental Fine Tuning
model = copy.deepcopy(copy_model)
incremental_learning_wandb(model, train_dataset, train_target, test_dataset, test_target,
                      num_tasks, classes_per_task, batch_size, num_epochs, lr, device)
wandb.finish()
Tracking run with wandb version 0.19.8
Run data is saved locally in /home/home/projects/tensor/projet_Machine_learning/wandb/run-20250322_050856-252cfjp6
Syncing run Fine tuning to Weights & Biases (docs)
View project at https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL
View run at https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL/runs/252cfjp6
Starting Task 1 - Training on classes: [0, 1, 2, 3, 4, 5, 6, 7]
Epoch 0: 125it [00:04, 25.83it/s]
100%|█████████████████████████████████████████| 125/125 [00:02<00:00, 54.40it/s]
Task 1, Epoch 1: Accuracy Train = 94.25%
Epoch 1: 125it [00:02, 47.46it/s]
100%|█████████████████████████████████████████| 125/125 [00:01<00:00, 65.62it/s]
Task 1, Epoch 2: Accuracy Train = 98.83%
Epoch 2: 125it [00:02, 43.63it/s]
100%|█████████████████████████████████████████| 125/125 [00:01<00:00, 66.64it/s]
Task 1, Epoch 3: Accuracy Train = 98.87%
Epoch 3: 125it [00:02, 44.90it/s]
100%|█████████████████████████████████████████| 125/125 [00:02<00:00, 62.00it/s]
Task 1, Epoch 4: Accuracy Train = 99.70%
Epoch 4: 125it [00:02, 45.29it/s]
100%|█████████████████████████████████████████| 125/125 [00:01<00:00, 66.70it/s]
Task 1, Epoch 5: Accuracy Train = 99.79%
100%|███████████████████████████████████████████| 61/61 [00:01<00:00, 34.98it/s]
Task 1: Accuracy Test = 91.71%
Starting Task 2 - Training on classes: [8, 9, 10, 11, 12, 13, 14, 15]
Epoch 0: 126it [00:06, 18.06it/s]
100%|█████████████████████████████████████████| 126/126 [00:02<00:00, 47.69it/s]
Task 2, Epoch 1: Accuracy Train = 99.40%
Epoch 1: 126it [00:02, 50.76it/s]
100%|█████████████████████████████████████████| 126/126 [00:01<00:00, 66.30it/s]
Task 2, Epoch 2: Accuracy Train = 99.95%
Epoch 2: 126it [00:02, 44.91it/s]
100%|█████████████████████████████████████████| 126/126 [00:01<00:00, 63.84it/s]
Task 2, Epoch 3: Accuracy Train = 99.96%
Epoch 3: 126it [00:02, 44.75it/s]
100%|█████████████████████████████████████████| 126/126 [00:01<00:00, 64.95it/s]
Task 2, Epoch 4: Accuracy Train = 99.99%
Epoch 4: 126it [00:02, 44.55it/s]
100%|█████████████████████████████████████████| 126/126 [00:01<00:00, 66.12it/s]
Task 2, Epoch 5: Accuracy Train = 99.98%
100%|█████████████████████████████████████████| 122/122 [00:03<00:00, 39.77it/s]
Task 2: Accuracy Test = 49.36%
Starting Task 3 - Training on classes: [16, 17, 18, 19, 20, 21, 22, 23]
Epoch 0: 49it [00:02, 17.47it/s]
100%|███████████████████████████████████████████| 49/49 [00:01<00:00, 45.80it/s]
Task 3, Epoch 1: Accuracy Train = 98.59%
Epoch 1: 49it [00:01, 45.09it/s]
100%|███████████████████████████████████████████| 49/49 [00:00<00:00, 67.70it/s]
Task 3, Epoch 2: Accuracy Train = 99.68%
Epoch 2: 49it [00:00, 51.68it/s]
100%|███████████████████████████████████████████| 49/49 [00:00<00:00, 68.11it/s]
Task 3, Epoch 3: Accuracy Train = 99.74%
Epoch 3: 49it [00:01, 47.47it/s]
100%|███████████████████████████████████████████| 49/49 [00:00<00:00, 65.84it/s]
Task 3, Epoch 4: Accuracy Train = 100.00%
Epoch 4: 49it [00:01, 44.71it/s]
100%|███████████████████████████████████████████| 49/49 [00:00<00:00, 65.98it/s]
Task 3, Epoch 5: Accuracy Train = 100.00%
100%|█████████████████████████████████████████| 144/144 [00:02<00:00, 54.06it/s]
Task 3: Accuracy Test = 12.85%
Starting Task 4 - Training on classes: [24, 25, 26, 27, 28, 29, 30, 31]
Epoch 0: 50it [00:02, 18.34it/s]
100%|███████████████████████████████████████████| 50/50 [00:01<00:00, 46.39it/s]
Task 4, Epoch 1: Accuracy Train = 96.01%
Epoch 1: 50it [00:01, 44.43it/s]
100%|███████████████████████████████████████████| 50/50 [00:00<00:00, 67.56it/s]
Task 4, Epoch 2: Accuracy Train = 99.37%
Epoch 2: 50it [00:00, 57.53it/s]
100%|███████████████████████████████████████████| 50/50 [00:00<00:00, 69.38it/s]
Task 4, Epoch 3: Accuracy Train = 99.87%
Epoch 3: 50it [00:00, 53.33it/s]
100%|███████████████████████████████████████████| 50/50 [00:00<00:00, 66.46it/s]
Task 4, Epoch 4: Accuracy Train = 99.91%
Epoch 4: 50it [00:01, 47.69it/s]
100%|███████████████████████████████████████████| 50/50 [00:00<00:00, 63.53it/s]
Task 4, Epoch 5: Accuracy Train = 100.00%
100%|█████████████████████████████████████████| 167/167 [00:03<00:00, 53.80it/s]
Task 4: Accuracy Test = 11.76%
Starting Task 5 - Training on classes: [32, 33, 34, 35, 36, 37, 38, 39]
Epoch 0: 60it [00:03, 18.13it/s]
100%|███████████████████████████████████████████| 60/60 [00:01<00:00, 47.98it/s]
Task 5, Epoch 1: Accuracy Train = 99.23%
Epoch 1: 60it [00:01, 50.25it/s]
100%|███████████████████████████████████████████| 60/60 [00:00<00:00, 70.68it/s]
Task 5, Epoch 2: Accuracy Train = 99.87%
Epoch 2: 60it [00:01, 47.79it/s]
100%|███████████████████████████████████████████| 60/60 [00:00<00:00, 65.88it/s]
Task 5, Epoch 3: Accuracy Train = 99.92%
Epoch 3: 60it [00:01, 45.29it/s]
100%|███████████████████████████████████████████| 60/60 [00:00<00:00, 66.15it/s]
Task 5, Epoch 4: Accuracy Train = 99.97%
Epoch 4: 60it [00:01, 46.30it/s]
100%|███████████████████████████████████████████| 60/60 [00:00<00:00, 66.43it/s]
Task 5, Epoch 5: Accuracy Train = 100.00%
100%|█████████████████████████████████████████| 194/194 [00:03<00:00, 50.21it/s]
Task 5: Accuracy Test = 12.87%



Run history:


incremental_accuracy█▄▁▁▁
task▁▃▅▆█
train/train_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▂▂▁▁▁▁▁▁▁▁▁▁▁


Run summary:


incremental_accuracy12.87328
task4
train/train_loss0.02354


View run Fine tuning at: https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL/runs/252cfjp6
View project at: https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20250322_050856-252cfjp6/logs

Section 8: Upper bound¶

In [25]:
#################################################
## Global upper bound (all data, all classes)
####

# One task + all classes computed using 5 epochs

model = copy.deepcopy(copy_model)
accu = incremental_learning(model, train_dataset, train_target, test_dataset, test_target,
                      1, (numclasses // num_tasks) * num_tasks, batch_size, 5, lr, device)

print(f"!!!!! Upper bound of accuracy = {accu[0]:.2f}%")
Starting Task 1 - Training on classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]
Epoch 0: 100%|████████████████████████████████| 407/407 [00:05<00:00, 68.49it/s]
100%|█████████████████████████████████████████| 407/407 [00:06<00:00, 66.28it/s]
Task 1, Epoch 1: Accuracy Train = 96.95%
Epoch 1: 100%|████████████████████████████████| 407/407 [00:05<00:00, 69.75it/s]
100%|█████████████████████████████████████████| 407/407 [00:06<00:00, 67.04it/s]
Task 1, Epoch 2: Accuracy Train = 98.92%
Epoch 2: 100%|████████████████████████████████| 407/407 [00:05<00:00, 70.56it/s]
100%|█████████████████████████████████████████| 407/407 [00:06<00:00, 67.61it/s]
Task 1, Epoch 3: Accuracy Train = 99.44%
Epoch 3: 100%|████████████████████████████████| 407/407 [00:05<00:00, 70.86it/s]
100%|█████████████████████████████████████████| 407/407 [00:06<00:00, 65.29it/s]
Task 1, Epoch 4: Accuracy Train = 99.64%
Epoch 4: 100%|████████████████████████████████| 407/407 [00:05<00:00, 70.10it/s]
100%|█████████████████████████████████████████| 407/407 [00:06<00:00, 66.00it/s]
Task 1, Epoch 5: Accuracy Train = 99.85%
100%|█████████████████████████████████████████| 194/194 [00:03<00:00, 56.78it/s]
Task 1: Accuracy Test = 91.00%
!!!!! Upper bound of accuracy = 91.00%

In [26]:
########################################
## Upper bound for each task (takes some time)
####
# initialise a wandb run
num_epochs = 5

run = wandb.init(
    project="GTSRB-CIL",
    name = "Upper bound",
    config={
        "epochs": num_epochs,
        "batch_size": batch_size,
        "num_tasks": num_tasks,
        "classes_per_task": classes_per_task,
        "lr": lr,
        })

# Non incremental data (learn all classes from all data for each task)
model = copy.deepcopy(copy_model)
incremental_learning_wandb(model, train_dataset, train_target, test_dataset, test_target,
                      num_tasks, classes_per_task, batch_size, num_epochs, lr, device, non_incremental = True)

wandb.finish()
Tracking run with wandb version 0.19.8
Run data is saved locally in /home/home/projects/tensor/projet_Machine_learning/wandb/run-20250322_051144-fy9f46ci
Syncing run Upper bound to Weights & Biases (docs)
View project at https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL
View run at https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL/runs/fy9f46ci
Starting Task 1 - Training on classes: [0, 1, 2, 3, 4, 5, 6, 7]
Epoch 0: 125it [00:02, 50.60it/s]
100%|█████████████████████████████████████████| 125/125 [00:01<00:00, 70.31it/s]
Task 1, Epoch 1: Accuracy Train = 66.49%
Epoch 1: 125it [00:02, 48.15it/s]
100%|█████████████████████████████████████████| 125/125 [00:01<00:00, 68.05it/s]
Task 1, Epoch 2: Accuracy Train = 90.68%
Epoch 2: 125it [00:02, 46.76it/s]
100%|█████████████████████████████████████████| 125/125 [00:01<00:00, 68.28it/s]
Task 1, Epoch 3: Accuracy Train = 96.42%
Epoch 3: 125it [00:02, 46.72it/s]
100%|█████████████████████████████████████████| 125/125 [00:01<00:00, 67.39it/s]
Task 1, Epoch 4: Accuracy Train = 97.55%
Epoch 4: 125it [00:02, 46.88it/s]
100%|█████████████████████████████████████████| 125/125 [00:01<00:00, 67.46it/s]
Task 1, Epoch 5: Accuracy Train = 98.52%
100%|███████████████████████████████████████████| 61/61 [00:00<00:00, 65.34it/s]
Task 1: Accuracy Test = 89.33%
Starting Task 2 - Training on classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
Epoch 0: 250it [00:05, 46.46it/s]
100%|█████████████████████████████████████████| 250/250 [00:03<00:00, 64.42it/s]
Task 2, Epoch 1: Accuracy Train = 89.45%
Epoch 1: 250it [00:05, 46.90it/s]
100%|█████████████████████████████████████████| 250/250 [00:03<00:00, 65.07it/s]
Task 2, Epoch 2: Accuracy Train = 96.90%
Epoch 2: 250it [00:05, 46.84it/s]
100%|█████████████████████████████████████████| 250/250 [00:03<00:00, 65.78it/s]
Task 2, Epoch 3: Accuracy Train = 98.82%
Epoch 3: 250it [00:05, 47.55it/s]
100%|█████████████████████████████████████████| 250/250 [00:03<00:00, 65.07it/s]
Task 2, Epoch 4: Accuracy Train = 99.62%
Epoch 4: 250it [00:05, 46.15it/s]
100%|█████████████████████████████████████████| 250/250 [00:03<00:00, 65.77it/s]
Task 2, Epoch 5: Accuracy Train = 99.62%
100%|█████████████████████████████████████████| 122/122 [00:01<00:00, 64.16it/s]
Task 2: Accuracy Test = 92.43%
Starting Task 3 - Training on classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
Epoch 0: 299it [00:06, 46.58it/s]
100%|█████████████████████████████████████████| 299/299 [00:04<00:00, 64.41it/s]
Task 3, Epoch 1: Accuracy Train = 93.37%
Epoch 1: 299it [00:06, 45.74it/s]
100%|█████████████████████████████████████████| 299/299 [00:04<00:00, 60.77it/s]
Task 3, Epoch 2: Accuracy Train = 98.53%
Epoch 2: 299it [00:06, 44.30it/s]
100%|█████████████████████████████████████████| 299/299 [00:04<00:00, 60.26it/s]
Task 3, Epoch 3: Accuracy Train = 98.52%
Epoch 3: 299it [00:06, 49.09it/s]
100%|█████████████████████████████████████████| 299/299 [00:04<00:00, 61.96it/s]
Task 3, Epoch 4: Accuracy Train = 99.34%
Epoch 4: 299it [00:06, 44.59it/s]
100%|█████████████████████████████████████████| 299/299 [00:04<00:00, 63.11it/s]
Task 3, Epoch 5: Accuracy Train = 99.76%
100%|█████████████████████████████████████████| 144/144 [00:02<00:00, 52.74it/s]
Task 3: Accuracy Test = 91.18%
Starting Task 4 - Training on classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
Epoch 0: 348it [00:07, 47.85it/s]
100%|█████████████████████████████████████████| 348/348 [00:05<00:00, 61.79it/s]
Task 4, Epoch 1: Accuracy Train = 93.79%
Epoch 1: 348it [00:07, 49.44it/s]
100%|█████████████████████████████████████████| 348/348 [00:05<00:00, 63.44it/s]
Task 4, Epoch 2: Accuracy Train = 98.32%
Epoch 2: 348it [00:07, 45.74it/s]
100%|█████████████████████████████████████████| 348/348 [00:05<00:00, 63.26it/s]
Task 4, Epoch 3: Accuracy Train = 99.02%
Epoch 3: 348it [00:07, 45.41it/s]
100%|█████████████████████████████████████████| 348/348 [00:05<00:00, 63.19it/s]
Task 4, Epoch 4: Accuracy Train = 99.20%
Epoch 4: 348it [00:07, 45.91it/s]
100%|█████████████████████████████████████████| 348/348 [00:05<00:00, 63.37it/s]
Task 4, Epoch 5: Accuracy Train = 99.61%
100%|█████████████████████████████████████████| 167/167 [00:03<00:00, 48.42it/s]
Task 4: Accuracy Test = 88.33%
Starting Task 5 - Training on classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]
Epoch 0: 407it [00:08, 47.51it/s]
100%|█████████████████████████████████████████| 407/407 [00:06<00:00, 60.86it/s]
Task 5, Epoch 1: Accuracy Train = 93.77%
Epoch 1: 407it [00:08, 47.60it/s]
100%|█████████████████████████████████████████| 407/407 [00:06<00:00, 60.09it/s]
Task 5, Epoch 2: Accuracy Train = 98.69%
Epoch 2: 407it [00:08, 47.49it/s]
100%|█████████████████████████████████████████| 407/407 [00:06<00:00, 61.13it/s]
Task 5, Epoch 3: Accuracy Train = 99.05%
Epoch 3: 407it [00:08, 49.19it/s]
100%|█████████████████████████████████████████| 407/407 [00:06<00:00, 61.79it/s]
Task 5, Epoch 4: Accuracy Train = 99.68%
Epoch 4: 407it [00:08, 47.85it/s]
100%|█████████████████████████████████████████| 407/407 [00:06<00:00, 61.82it/s]
Task 5, Epoch 5: Accuracy Train = 99.52%
100%|█████████████████████████████████████████| 194/194 [00:04<00:00, 47.40it/s]
Task 5: Accuracy Test = 89.81%



Run history:


incremental_accuracy▃█▆▁▄
task▁▃▅▆█
train/train_loss▅▂▂▂▁▂▂▁▁▁▁▅▂▁▁▁▁▁▁█▂▁▁▁▁▁▁▁▁▁▃▂▂▁▁▁▁▁▁▁


Run summary:


incremental_accuracy89.81437
task4
train/train_loss0.00503


View run Upper bound at: https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL/runs/fy9f46ci
View project at: https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20250322_051144-fy9f46ci/logs

Section 9: Rehearsal Memory Buffer + Knowledge Distillation Loss (KD)¶

Knowledge Distillation:

$$ \mathcal{L} = \underbrace{\ell(f(\mathbf{x}), y)}_{\text{Learning New Classes}} + \underbrace{\sum_{k=1}^{|\mathcal{Y}_{b-1}|} -S_k(f^{b-1}(\mathbf{x})) \log S_k(f(\mathbf{x})) }_{\text{Remembering Old Classes}} $$

  • $\ell(f(x), y)$: Standard cross-entropy loss for learning new classes.

  • Second term: Knowledge Distillation, encouraging the new model's output $f(x)$ to match the old model's output $f^{b-1}(x)$.

  • $S_k$: Softmax function applied to logits.

A version with temperature-scaled softmax:

$$ \mathcal{L} = (1-\lambda) \underbrace{\ell(f(\mathbf{x}), y)}_{\text{Learning New Classes}} + \lambda \underbrace{ \sum_{k=1}^{|\mathcal{Y}_{b-1}|} -S_k\left( \frac{f^{b-1}(\mathbf{x})}{T} \right) \log S_k\left( \frac{f(\mathbf{x})}{T} \right) }_{\text{Remembering Old Classes (KD with Temperature $T$)}} $$

  • $\ell(f(x), y)$: Cross-entropy loss for the current new classes. $\leadsto$   loss_cls = criterion(logits, targets)

  • The second term uses softmax temperature scaling:

    $$ S_k\left( \frac{f(x)}{T} \right) = \frac{\exp\left( f_k(x) / T \right)}{\sum_{j} \exp\left( f_j(x) / T \right)}, $$

    where $T$ is the temperature hyperparameter, typically $T > 1$ to soften the distribution.

    $\leadsto$   loss_kd = KD_loss(logits[:, :len(current_classes)-classes_per_task], old_logits, T)

  • $\lambda = \frac{{|\mathcal{Y}_{b-1}|}}{{|\mathcal{Y}_{b}|}}$: A balancing coefficient, the proportion of old classes among all classes.
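Putting the pieces together, the $\lambda$-weighted loss above can be sketched as follows. This is a minimal illustration with random logits; `kd_ce_loss`, `n_old`, and `n_total` are names introduced here for the sketch, not taken from the notebook:

```python
import torch
import torch.nn.functional as F

def kd_ce_loss(logits, targets, old_logits, n_old, n_total, T=2.0):
    """Cross-entropy on the current data plus temperature-scaled KD on old classes."""
    lam = n_old / n_total  # proportion of old classes among all classes
    loss_cls = F.cross_entropy(logits, targets)
    # Distill only over the first n_old logits; the T*T factor keeps
    # gradient magnitudes comparable across temperatures.
    loss_kd = F.kl_div(
        F.log_softmax(logits[:, :n_old] / T, dim=1),
        F.softmax(old_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    return (1 - lam) * loss_cls + lam * loss_kd

# Toy batch: 4 samples, 8 old classes out of 16 total
logits = torch.randn(4, 16)
old_logits = torch.randn(4, 8)
targets = torch.randint(0, 16, (4,))
loss = kd_ce_loss(logits, targets, old_logits, n_old=8, n_total=16)
```

Note that the training loop later in this notebook uses the unweighted sum `loss_cls + loss_kd`; the $\lambda$-weighted form above is the variant given by the formula.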

In [27]:
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset, Subset

import torch.nn.functional as F

def KD_loss(student_logits, teacher_logits, T=2):
    student_log_prob = F.log_softmax(student_logits / T, dim=1)
    teacher_prob = F.softmax(teacher_logits / T, dim=1)
    return F.kl_div(student_log_prob, teacher_prob, reduction='batchmean') * (T * T)

Rehearsal Memory Buffer

$$ \mathcal{L} = \sum_{(\mathbf{x}, y) \in (\mathcal{D}^{b} \cup \mathcal{E})} \ell(f(\mathbf{x}), y). $$

  • $\mathcal{D}^{b}$ represents the data of the current task

  • $\mathcal{E}$ denotes the exemplar set (Rehearsal Buffer) containing selected samples from previous tasks.

During training, the model jointly uses the current task data $\mathcal{D}^{b}$ and the stored exemplars $\mathcal{E}$, helping the model to retain knowledge from old tasks while learning new ones.

$\leadsto$   combined_dataset = ConcatDataset([current_subset, memory_dataset])
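The buffer mechanics can be sketched in isolation (toy random tensors stand in for GTSRB images; all sizes here are illustrative, not the notebook's):

```python
import torch
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader

# Current-task data D^b: 40 fake 3x32x32 "images" labelled with new classes 8..15
current_subset = TensorDataset(torch.randn(40, 3, 32, 32),
                               torch.randint(8, 16, (40,)))

# Exemplar set E: 16 samples retained from old classes 0..7
memory_dataset = TensorDataset(torch.randn(16, 3, 32, 32),
                               torch.randint(0, 8, (16,)))

# Train jointly on new data plus rehearsal exemplars
combined_dataset = ConcatDataset([current_subset, memory_dataset])
loader = DataLoader(combined_dataset, batch_size=8, shuffle=True)

images, labels = next(iter(loader))
```

Because the loader shuffles the concatenated dataset, each mini-batch mixes new-task samples with old-class exemplars, which is what lets the cross-entropy term keep refreshing the old classes.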

In [28]:
def incremental_learning_rehearsal_KD_wandb(model, train_dataset, train_target, test_dataset, test_target,
                         num_tasks, classes_per_task, batch_size, num_epochs, lr, device):

    nclasses = len(np.unique(train_target))
    all_classes = list(range(nclasses))
    criterion = nn.CrossEntropyLoss()
    current_classes = []
    accuracies = []

    # Rehearsal (replay) buffer
    total_mem_size = 200  # Total memory buffer size
    memory_data = []      # Stored exemplar images
    memory_labels = []    # Stored exemplar labels

    old_model = None
    T = 2   # Temperature for knowledge distillation

    wandb.define_metric("task")
    wandb.define_metric("incremental_accuracy", step_metric="task")

    for task in range(num_tasks):
        # Select the classes of the current task
        task_classes = all_classes[task * classes_per_task : (task + 1) * classes_per_task]
        current_classes.extend(task_classes)

        # === Create the subset for the new task ===
        train_indices = [i for i, label in enumerate(train_target) if label in task_classes]

        # Extract subset data & labels and convert them to a TensorDataset
        subset_data = []
        subset_labels = []

        for idx in train_indices:
            img, label = train_dataset[idx]
            subset_data.append(img)
            subset_labels.append(label)

        subset_data = torch.stack(subset_data)
        subset_labels = torch.tensor(subset_labels, dtype=torch.long)

        current_subset = TensorDataset(subset_data, subset_labels)


        # === Build the memory-buffer dataset and merge it with the current task ===
        if memory_data:
            mem_tensor_data = torch.stack(memory_data)
            mem_tensor_labels = torch.tensor(memory_labels, dtype=torch.long)
            memory_dataset = TensorDataset(mem_tensor_data, mem_tensor_labels)
            combined_dataset = ConcatDataset([current_subset, memory_dataset])
        else:
            combined_dataset = current_subset

        # Create DataLoaders
        train_loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True)
        test_loader = create_dataloader(test_dataset, test_target, current_classes, batch_size, shuffle=False)

        # === Expand the output layer for new classes ===
        if task == 0:
            model.fc = nn.Linear(model.fc.in_features, len(current_classes)).to(device)
        else:
            old_weight = model.fc.weight.data.clone()
            old_bias = model.fc.bias.data.clone()
            model.fc = nn.Linear(model.fc.in_features, len(current_classes)).to(device)
            model.fc.weight.data[:len(old_weight)] = old_weight
            model.fc.bias.data[:len(old_bias)] = old_bias

        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

        print(f"Task {task+1}: Training on classes {task_classes}")


        # === Train the model ===
        for epoch in range(num_epochs):
            model.train()
            running_loss, correct, total = 0.0, 0, 0
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                logits = model(inputs)
                loss_cls = criterion(logits, targets)

                # === KD loss guided by the old (frozen) model ===
                if old_model is not None:
                    with torch.no_grad():
                        old_logits = old_model(inputs)
                    loss_kd = KD_loss(logits[:, :len(current_classes)-classes_per_task], old_logits, T)
                    loss = loss_cls + loss_kd
                else:
                    loss = loss_cls

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                preds = logits.argmax(dim=1)
                correct += (preds == targets).sum().item()
                total += targets.size(0)

            scheduler.step()
            train_acc = correct / total * 100
            print(f"Task {task+1}, Epoch {epoch+1}: Loss {running_loss:.3f}, Train Acc {train_acc:.2f}%")

            wandb.log({
                "train/train_loss": running_loss,
                "train/train_accuracy": train_acc,
                "epoch": epoch,
                "task": task
            })

        # === Evaluate on the test set ===
        test_acc = evaluate(model, test_loader, device)
        accuracies.append(test_acc)
        print(f"Task {task+1}: Test Acc = {test_acc:.2f}%")
        wandb.log({"incremental_accuracy": test_acc, "task": task})

        # === Update the memory buffer ===
        mem_per_class = total_mem_size // len(current_classes)      # Exemplars allocated per class
        memory_data = []
        memory_labels = []
        for cls in current_classes:
            indices = [i for i, label in enumerate(train_target) if label == cls]
            selected = np.random.choice(indices, min(mem_per_class, len(indices)), replace=False)
            for idx in selected:
                img = train_dataset[idx][0]
                memory_data.append(img)
                memory_labels.append(cls)

        print(f"Memory Buffer: {len(memory_labels)} samples total")

        # Store a frozen copy of the current model as the old model
        old_model = copy.deepcopy(model).to(device).eval()

    return accuracies
In [29]:
num_epochs = 5

run = wandb.init(
    project="GTSRB-CIL",
    name="Rehearsal_KD",
    config={
        "epochs": num_epochs,
        "batch_size": batch_size,
        "num_tasks": num_tasks,
        "classes_per_task": classes_per_task,
        "lr": lr,
        "rehearsal_buffer_size": 200,
        "kd_temperature": 2
    })
Tracking run with wandb version 0.19.8
Run data is saved locally in /home/home/projects/tensor/projet_Machine_learning/wandb/run-20250322_051625-s2fkx109
Syncing run Rehearsal_KD to Weights & Biases (docs)
View project at https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL
View run at https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL/runs/s2fkx109
In [30]:
model = copy.deepcopy(copy_model)  # Start from the pre-trained model

incremental_learning_rehearsal_KD_wandb(
    model, 
    train_dataset, 
    train_target, 
    test_dataset, 
    test_target,
    num_tasks, 
    classes_per_task, 
    batch_size, 
    num_epochs, 
    lr, 
    device
)

wandb.finish()
Task 1: Training on classes [0, 1, 2, 3, 4, 5, 6, 7]
Task 1, Epoch 1: Loss 96.120, Train Acc 76.68%
Task 1, Epoch 2: Loss 34.481, Train Acc 93.79%
Task 1, Epoch 3: Loss 24.019, Train Acc 95.57%
Task 1, Epoch 4: Loss 18.945, Train Acc 96.63%
Task 1, Epoch 5: Loss 15.440, Train Acc 97.45%
100%|███████████████████████████████████████████| 61/61 [00:00<00:00, 71.29it/s]
Task 1: Test Acc = 90.21%
Memory Buffer: 200 samples total
Task 2: Training on classes [8, 9, 10, 11, 12, 13, 14, 15]
Task 2, Epoch 1: Loss 108.750, Train Acc 87.26%
Task 2, Epoch 2: Loss 20.895, Train Acc 96.93%
Task 2, Epoch 3: Loss 14.095, Train Acc 97.84%
Task 2, Epoch 4: Loss 11.093, Train Acc 98.43%
Task 2, Epoch 5: Loss 9.449, Train Acc 98.65%
100%|█████████████████████████████████████████| 122/122 [00:02<00:00, 56.74it/s]
Task 2: Test Acc = 74.71%
Memory Buffer: 192 samples total
Task 3: Training on classes [16, 17, 18, 19, 20, 21, 22, 23]
Task 3, Epoch 1: Loss 113.056, Train Acc 73.07%
Task 3, Epoch 2: Loss 20.128, Train Acc 93.99%
Task 3, Epoch 3: Loss 10.361, Train Acc 96.59%
Task 3, Epoch 4: Loss 7.401, Train Acc 97.68%
Task 3, Epoch 5: Loss 5.821, Train Acc 98.31%
100%|█████████████████████████████████████████| 144/144 [00:02<00:00, 62.60it/s]
Task 3: Test Acc = 68.27%
Memory Buffer: 192 samples total
Task 4: Training on classes [24, 25, 26, 27, 28, 29, 30, 31]
Task 4, Epoch 1: Loss 131.107, Train Acc 63.29%
Task 4, Epoch 2: Loss 22.356, Train Acc 92.65%
Task 4, Epoch 3: Loss 13.884, Train Acc 95.91%
Task 4, Epoch 4: Loss 10.675, Train Acc 97.00%
Task 4, Epoch 5: Loss 8.928, Train Acc 97.51%
100%|█████████████████████████████████████████| 167/167 [00:02<00:00, 66.06it/s]
Task 4: Test Acc = 56.02%
Memory Buffer: 192 samples total
Task 5: Training on classes [32, 33, 34, 35, 36, 37, 38, 39]
Task 5, Epoch 1: Loss 117.756, Train Acc 76.96%
Task 5, Epoch 2: Loss 15.510, Train Acc 95.87%
Task 5, Epoch 3: Loss 9.946, Train Acc 97.18%
Task 5, Epoch 4: Loss 7.597, Train Acc 97.96%
Task 5, Epoch 5: Loss 6.423, Train Acc 98.29%
100%|█████████████████████████████████████████| 194/194 [00:02<00:00, 67.68it/s]
Task 5: Test Acc = 56.77%
Memory Buffer: 200 samples total


Run history:


epoch▁▃▅▆█▁▃▅▆█▁▃▅▆█▁▃▅▆█▁▃▅▆█
incremental_accuracy█▅▄▁▁
task▁▁▁▁▁▁▃▃▃▃▃▃▅▅▅▅▅▅▆▆▆▆▆▆██████
train/train_accuracy▄▇▇██▆████▃▇███▁▇▇██▄▇███
train/train_loss▆▃▂▂▂▇▂▁▁▁▇▂▁▁▁█▂▁▁▁▇▂▁▁▁


Run summary:


epoch4
incremental_accuracy56.77159
task4
train/train_accuracy98.28802
train/train_loss6.42265


View run Rehearsal_KD at: https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL/runs/s2fkx109
View project at: https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20250322_051625-s2fkx109/logs

Section 10: Dynamically Expandable Representation (DER)¶

In [31]:
from collections import Counter

# Count occurrences of each label, sort by most common
label_counts = Counter(train_target).most_common()

# Print statistics for each class
print("Count\tClassID\tClassName")
for l, c in label_counts:
    print(f"{c}\t{l}\t{class_names[l]}")  # format: count, class ID, class name
Count	ClassID	ClassName
1500	1	Speed limit (30km/h)
1500	2	Speed limit (50km/h)
1440	13	Yield
1410	12	Priority road
1380	38	Keep right
1350	10	No passing for vehicles over 3.5 metric tons
1320	4	Speed limit (70km/h)
1260	5	Speed limit (80km/h)
1020	25	Road work
990	9	No passing
960	3	Speed limit (60km/h)
960	7	Speed limit (100km/h)
960	8	Speed limit (120km/h)
900	11	Right-of-way at the next intersection
810	18	General caution
810	35	Ahead only
750	17	No entry
540	14	Stop
540	31	Wild animals crossing
480	33	Turn right ahead
420	15	No vehicles
420	26	Traffic signals
360	23	Slippery road
360	28	Children crossing
300	6	End of speed limit (80km/h)
300	16	Vehicles over 3.5 metric tons prohibited
300	30	Beware of ice/snow
300	34	Turn left ahead
270	22	Bumpy road
270	36	Go straight or right
240	20	Dangerous curve to the right
240	21	Double curve
240	40	Roundabout mandatory
210	39	Keep left
180	24	Road narrows on the right
180	27	Pedestrians
180	29	Bicycles crossing
180	32	End of all speed and passing limits
180	41	End of no passing
180	42	End of no passing by vehicles over 3.5 metric tons
150	0	Speed limit (20km/h)
150	19	Dangerous curve to the left
150	37	Go straight or left

$\leadsto$   Remark 2

  • Buffer Size Strategy :

    Section 9: total_mem_size = 200, mem_per_class = total_mem_size // len(current_classes). The total buffer size is fixed at 200 samples, and after each task the buffer is evenly divided among all seen classes.

    For example:
    Task 1: 200/8 = 25 samples per class
    Task 5: 200/40 = 5 samples per class

    Eventually, old classes have very few samples, leading to severe forgetting in later stages.

    Section 10: mem_per_class = 10. The buffer keeps a fixed quota of 10 samples per class, so the total grows with the number of seen classes.

    For example:
    Task 1: Seen Classes = 8 | Buffer Size = 80
    Task 5: Seen Classes = 40 | Buffer Size = 400

  • How Buffer Is Used During Training :

    Section 9: combined_dataset = ConcatDataset([current_task_data, memory_dataset]). Simply concatenates new-class data and buffer samples $\leadsto$ the amount of new-class data is much larger than the old-class data

    Section 10: new_loader = DataLoader(new_data, batch_size//2), mem_loader = DataLoader(memory_data, batch_size//2). Each batch contains 50% new-class + 50% buffer samples $\leadsto$ Balanced Sampling

  • Bias Correction :

    Section 9: No additional Bias Correction step.

    Section 10: buffer_loader = DataLoader(memory_dataset, batch_size); for epoch in range(1): for inputs, targets in buffer_loader: ... Fine-tune only on buffer samples $\leadsto$ Helps to preserve the decision boundaries of old classes
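As a sanity check on the arithmetic above, the two allocation policies can be compared directly (a standalone sketch; `equal_split` and `fixed_per_class` are illustrative helper names, not functions from this notebook):

```python
# Buffer allocation under the two policies (classes_per_task = 8, 5 tasks,
# as in this notebook).

def equal_split(total_mem_size, n_seen):
    # Section 9 policy: fixed total buffer, shared equally among seen classes
    return total_mem_size // n_seen

def fixed_per_class(mem_per_class, n_seen):
    # Section 10 policy: fixed per-class quota, total grows with seen classes
    return mem_per_class * n_seen

for task in range(1, 6):
    n_seen = 8 * task
    print(f"Task {task}: equal split -> {equal_split(200, n_seen)} per class, "
          f"fixed quota -> total buffer {fixed_per_class(10, n_seen)}")
# Task 1: 25 per class vs. total 80; Task 5: 5 per class vs. total 400
```

The per-class quota trades a growing memory footprint for a per-class sample count that stays constant as tasks accumulate.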

In [32]:
def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader, ncols=80):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            if isinstance(outputs, tuple):  # DER model: take the main classifier output
                outputs = outputs[0]
                
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

Backbone Expansion

$$ f(x) = W_{\text{new}}^{\top} [\phi_{\text{old}}(x), \phi_{\text{new}}(x)] $$

  • $\phi_{\text{old}}(x)$: Feature representation from the old (previous) backbone (frozen after training).
  • $\phi_{\text{new}}(x)$: Feature representation extracted by the new backbone for the current task.
  • $[\cdot, \cdot]$: Feature concatenation operation.
  • $W_{\text{new}}$: Fully connected layer applied after feature aggregation, dynamically expanded.

$\leadsto$   features = F.normalize(self.feature_fc(x), p=2, dim=1); out = self.classifier(features) (W_new applied to the features); aux_out = self.aux_classifier(features) (auxiliary classifier)
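Note that the DER_CNN below simplifies this formula to a single shared backbone with an expandable head. The full DER-style expansion, with a frozen old extractor concatenated to a new one, can be sketched roughly as follows (module and dimension names are illustrative, not the notebook's own):

```python
import torch
import torch.nn as nn

class DERHead(nn.Module):
    """Sketch: f(x) = W_new^T [phi_old(x), phi_new(x)] with phi_old frozen."""
    def __init__(self, phi_old, phi_new, feat_dim, n_classes):
        super().__init__()
        self.phi_old = phi_old                 # old backbone, frozen after its task
        self.phi_new = phi_new                 # new backbone, trained on current task
        for p in self.phi_old.parameters():
            p.requires_grad_(False)
        # W_new acts on the concatenated features [phi_old(x), phi_new(x)]
        self.classifier = nn.Linear(2 * feat_dim, n_classes)

    def forward(self, x):
        with torch.no_grad():
            f_old = self.phi_old(x)            # frozen branch: no gradients
        f_new = self.phi_new(x)
        return self.classifier(torch.cat([f_old, f_new], dim=1))

# Toy check with linear "backbones" on 16-dim inputs
model = DERHead(nn.Linear(16, 32), nn.Linear(16, 32), feat_dim=32, n_classes=24)
out = model(torch.randn(4, 16))
print(out.shape)  # torch.Size([4, 24])
```

Freezing $\phi_{\text{old}}$ is what protects the old representation from being overwritten; only $\phi_{\text{new}}$ and $W_{\text{new}}$ receive gradients.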

Loss Function (Cross-task classification)

$$ \mathcal{L} = \sum_{k=1}^{|\mathcal{Y}_b|} -\mathbb{I}(y = k) \log S_k\left( W_{\text{new}}^{\top} [\bar{\phi}_{\text{old}}(x), \phi_{\text{new}}(x)] \right) $$

  • $\bar{\phi}_{\text{old}}(x)$: Indicates the frozen old backbone.
  • $\mathcal{Y}_b$: The set of all classes learned up to the current task.
  • $S_k(\cdot)$: Softmax function applied to compute the probability of class $k$.

Additionally, DER employs an **auxiliary loss** to encourage better differentiation between old and new classes (optional).

$\leadsto$   def expand_output(self, n_new_classes): old_weight = self.classifier.weight.data.clone(); self.classifier = nn.Linear(..., total_classes)
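The training loop below also calls a KD_loss helper (defined for Section 9, outside this excerpt). A minimal version consistent with how it is called there, i.e. new-model logits restricted to the old classes, old-model logits, and a temperature T, might look like:

```python
import torch
import torch.nn.functional as F

def KD_loss(new_logits, old_logits, T):
    # Soften both distributions with temperature T, then take the KL divergence
    # from the old model's soft targets to the new model's predictions.
    log_p_new = F.log_softmax(new_logits / T, dim=1)
    p_old = F.softmax(old_logits / T, dim=1)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures
    return F.kl_div(log_p_new, p_old, reduction="batchmean") * (T ** 2)

# Identical logits give a (near-)zero distillation loss
z = torch.randn(8, 16)
print(KD_loss(z, z, T=2).item())  # ~0.0
```

This is a sketch of the standard Hinton-style distillation loss; the notebook's actual KD_loss may differ in details such as reduction or the T^2 scaling.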

In [33]:
import torch.nn as nn
import torch.nn.functional as F

class DER_CNN(nn.Module):
    def __init__(self, n_in=3, base_feature_dim=128):
        super(DER_CNN, self).__init__()

        # Shared feature extractor
        self.conv1 = nn.Conv2d(n_in, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

        self.feature_fc = nn.Linear(64 * 8 * 8, base_feature_dim)

        # Dynamically expandable classifier
        self.classifier = nn.Linear(base_feature_dim, 0)  # start with 0 output
        self.aux_classifier = nn.Linear(base_feature_dim, 0)  # auxiliary classifier

    def expand_output(self, n_new_classes):
        # Main classifier
        old_weight = self.classifier.weight.data.clone() if self.classifier.out_features > 0 else None
        old_bias = self.classifier.bias.data.clone() if self.classifier.out_features > 0 else None
        total_classes = self.classifier.out_features + n_new_classes
        self.classifier = nn.Linear(self.feature_fc.out_features, total_classes).to(self.feature_fc.weight.device)
        if old_weight is not None:
            self.classifier.weight.data[:old_weight.size(0)] = old_weight
            self.classifier.bias.data[:old_bias.size(0)] = old_bias

        # Auxiliary classifier
        old_aux_weight = self.aux_classifier.weight.data.clone() if self.aux_classifier.out_features > 0 else None
        old_aux_bias = self.aux_classifier.bias.data.clone() if self.aux_classifier.out_features > 0 else None
        self.aux_classifier = nn.Linear(self.feature_fc.out_features, total_classes).to(self.feature_fc.weight.device)
        if old_aux_weight is not None:
            self.aux_classifier.weight.data[:old_aux_weight.size(0)] = old_aux_weight
            self.aux_classifier.bias.data[:old_aux_bias.size(0)] = old_aux_bias

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        features = self.relu(self.feature_fc(x))
        features = F.normalize(features, p=2, dim=1)  # Add feature normalization
        out = self.classifier(features)
        aux_out = self.aux_classifier(features)
        return out, aux_out, features
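The expansion step can be exercised in isolation. The helper below mirrors the weight-preserving logic of expand_output on a bare nn.Linear (a standalone sketch, not the notebook's exact method):

```python
import torch
import torch.nn as nn

def expand_linear(layer, feat_dim, n_new):
    """Grow a Linear classifier by n_new outputs, keeping the old weights."""
    old_w = layer.weight.data.clone() if layer.out_features > 0 else None
    old_b = layer.bias.data.clone() if layer.out_features > 0 else None
    new_layer = nn.Linear(feat_dim, layer.out_features + n_new)
    if old_w is not None:
        new_layer.weight.data[:old_w.size(0)] = old_w  # copy old rows back
        new_layer.bias.data[:old_b.size(0)] = old_b
    return new_layer

clf = nn.Linear(128, 0)             # start with an empty head, as in DER_CNN
clf = expand_linear(clf, 128, 8)    # task 1: 8 classes
w_task1 = clf.weight.data.clone()
clf = expand_linear(clf, 128, 8)    # task 2: 16 classes
print(clf.out_features)                           # 16
print(torch.equal(clf.weight.data[:8], w_task1))  # True
```

The old class rows survive each expansion unchanged; only the newly appended rows start from fresh initialization.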
In [34]:
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
import torch.optim as optim
from torch.optim import lr_scheduler

def incremental_learning_DER(model, train_dataset, train_target, test_dataset, test_target,
                         num_tasks, classes_per_task, batch_size, num_epochs, lr, device):

    nclasses = len(np.unique(train_target))
    all_classes = list(range(nclasses))
    criterion = nn.CrossEntropyLoss()
    current_classes = []
    accuracies = []

    # Fixed per class buffer
    mem_per_class = 10
    memory_data = []
    memory_labels = []

    old_model = None
    T = 2  # KD temperature
    alpha = 0.7  # Aux loss weight
    beta = 1.5   # KD loss weight

    wandb.define_metric("task")
    wandb.define_metric("incremental_accuracy", step_metric="task")

    for task in range(num_tasks):
        task_classes = all_classes[task * classes_per_task : (task + 1) * classes_per_task]
        current_classes.extend(task_classes)

        model.expand_output(len(task_classes))

        # Prepare new class data
        train_indices = [i for i, label in enumerate(train_target) if label in task_classes]
        subset_data, subset_labels = [], []
        for idx in train_indices:
            img, label = train_dataset[idx]
            subset_data.append(img)
            subset_labels.append(label)
        subset_data = torch.stack(subset_data)
        subset_labels = torch.tensor(subset_labels, dtype=torch.long)
        current_subset = TensorDataset(subset_data, subset_labels)

        # Add buffer data
        if memory_data:
            mem_tensor_data = torch.stack(memory_data)
            mem_tensor_labels = torch.tensor(memory_labels, dtype=torch.long)
            memory_dataset = TensorDataset(mem_tensor_data, mem_tensor_labels)
        else:
            memory_dataset = None

        test_loader = create_dataloader(test_dataset, test_target, current_classes, batch_size, shuffle=False)

        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.3)

        print(f"Task {task+1}: Training on classes {task_classes}")

        # -------- Balanced Sampling ---------
        new_loader = DataLoader(current_subset, batch_size=batch_size//2, shuffle=True)
        if memory_dataset:
            mem_loader = DataLoader(memory_dataset, batch_size=batch_size//2, shuffle=True)
        else:
            mem_loader = None

        for epoch in range(num_epochs):
            model.train()
            running_loss, correct, total = 0.0, 0, 0

            new_iter = iter(new_loader)
            mem_iter = iter(mem_loader) if mem_loader else None
            steps = len(new_loader)

            for step in range(steps):
                try:
                    new_inputs, new_targets = next(new_iter)
                except StopIteration:
                    break

                if mem_iter:
                    try:
                        mem_inputs, mem_targets = next(mem_iter)
                    except StopIteration:
                        mem_iter = iter(mem_loader)
                        mem_inputs, mem_targets = next(mem_iter)
                    inputs = torch.cat([new_inputs, mem_inputs], dim=0).to(device)
                    targets = torch.cat([new_targets, mem_targets], dim=0).to(device)
                else:
                    inputs, targets = new_inputs.to(device), new_targets.to(device)

                logits, aux_logits, _ = model(inputs)
                loss_cls = criterion(logits, targets)
                loss_aux = criterion(aux_logits, targets)

                if old_model:
                    with torch.no_grad():
                        old_logits, _, _ = old_model(inputs)
                    kd_loss = KD_loss(logits[:, :len(current_classes)-classes_per_task], old_logits, T)
                    loss = loss_cls + alpha * loss_aux + beta * kd_loss
                else:
                    loss = loss_cls + alpha * loss_aux

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                preds = logits.argmax(dim=1)
                correct += (preds == targets).sum().item()
                total += targets.size(0)

            scheduler.step()
            train_acc = correct / total * 100
            print(f"Task {task+1}, Epoch {epoch+1}: Loss {running_loss:.3f}, Train Acc {train_acc:.2f}%")

            wandb.log({
                "train/train_loss": running_loss,
                "train/train_accuracy": train_acc,
                "epoch": epoch,
                "task": task
            })

        # --------- Evaluate ---------
        test_acc = evaluate(model, test_loader, device)
        accuracies.append(test_acc)
        print(f"Task {task+1}: Test Acc = {test_acc:.2f}%")
        wandb.log({"incremental_accuracy": test_acc, "task": task})

        # --------- Update Buffer (Fixed per class) ---------
        memory_data = []
        memory_labels = []
        for cls in current_classes:
            indices = [i for i, label in enumerate(train_target) if label == cls]
            selected = np.random.choice(indices, min(mem_per_class, len(indices)), replace=False)
            for idx in selected:
                img = train_dataset[idx][0]
                memory_data.append(img)
                memory_labels.append(cls)

        print(f"Memory Buffer updated: {len(memory_labels)} samples")

        # --------- Bias Correction Step (Fine-tune on buffer) ---------
        # Note: memory_dataset was built before this task's training, so it
        # holds the previous buffer (old classes only) and is None at task 1.
        if memory_dataset:
            buffer_loader = DataLoader(memory_dataset, batch_size=batch_size, shuffle=True)
            for epoch in range(1):
                for inputs, targets in buffer_loader:
                    inputs, targets = inputs.to(device), targets.to(device)
                    logits, aux_logits, _ = model(inputs)
                    loss_cls = criterion(logits, targets)
                    optimizer.zero_grad()
                    loss_cls.backward()
                    optimizer.step()
            print("Bias correction fine-tuning done.")

        # Copy model for KD
        old_model = copy.deepcopy(model).to(device).eval()

    return accuracies
In [35]:
num_epochs = 5

run = wandb.init(
    project="GTSRB-CIL",
    name="DER",
    config={
        "epochs": num_epochs,
        "batch_size": batch_size,
        "num_tasks": num_tasks,
        "classes_per_task": classes_per_task,
        "lr": lr,
        "rehearsal_buffer_size": 10,
        "kd_temperature": 2
    })

# Initialize DER model
model = DER_CNN(n_in=3).to(device)

incremental_learning_DER(
    model, 
    train_dataset, 
    train_target, 
    test_dataset, 
    test_target,
    num_tasks, 
    classes_per_task, 
    batch_size, 
    num_epochs, 
    lr, 
    device
)

wandb.finish()
Tracking run with wandb version 0.19.8
Run data is saved locally in /home/home/projects/tensor/projet_Machine_learning/wandb/run-20250322_051655-es0akd2d
Syncing run DER to Weights & Biases (docs)
View project at https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL
View run at https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL/runs/es0akd2d
/home/home/miniconda3/envs/tensor/lib/python3.11/site-packages/torch/nn/init.py:511: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")
Task 1: Training on classes [0, 1, 2, 3, 4, 5, 6, 7]
Task 1, Epoch 1: Loss 771.271, Train Acc 31.03%
Task 1, Epoch 2: Loss 505.882, Train Acc 73.91%
Task 1, Epoch 3: Loss 258.374, Train Acc 92.98%
Task 1, Epoch 4: Loss 153.143, Train Acc 96.70%
Task 1, Epoch 5: Loss 123.441, Train Acc 97.35%
100%|███████████████████████████████████████████| 61/61 [00:00<00:00, 73.79it/s]
Task 1: Test Acc = 90.16%
Memory Buffer updated: 80 samples
Task 2: Training on classes [8, 9, 10, 11, 12, 13, 14, 15]
Task 2, Epoch 1: Loss 660.366, Train Acc 61.43%
Task 2, Epoch 2: Loss 386.230, Train Acc 95.53%
Task 2, Epoch 3: Loss 227.683, Train Acc 99.76%
Task 2, Epoch 4: Loss 164.217, Train Acc 99.97%
Task 2, Epoch 5: Loss 144.893, Train Acc 99.98%
100%|█████████████████████████████████████████| 122/122 [00:01<00:00, 68.84it/s]
Task 2: Test Acc = 90.57%
Memory Buffer updated: 160 samples
Bias correction fine-tuning done.
Task 3: Training on classes [16, 17, 18, 19, 20, 21, 22, 23]
Task 3, Epoch 1: Loss 257.001, Train Acc 60.10%
Task 3, Epoch 2: Loss 203.053, Train Acc 78.44%
Task 3, Epoch 3: Loss 166.570, Train Acc 85.33%
Task 3, Epoch 4: Loss 145.274, Train Acc 89.21%
Task 3, Epoch 5: Loss 136.623, Train Acc 91.59%
100%|█████████████████████████████████████████| 144/144 [00:02<00:00, 69.29it/s]
Task 3: Test Acc = 85.74%
Memory Buffer updated: 240 samples
Bias correction fine-tuning done.
Task 4: Training on classes [24, 25, 26, 27, 28, 29, 30, 31]
Task 4, Epoch 1: Loss 319.231, Train Acc 55.37%
Task 4, Epoch 2: Loss 264.634, Train Acc 66.18%
Task 4, Epoch 3: Loss 223.125, Train Acc 78.14%
Task 4, Epoch 4: Loss 196.497, Train Acc 88.24%
Task 4, Epoch 5: Loss 185.304, Train Acc 89.82%
100%|█████████████████████████████████████████| 167/167 [00:02<00:00, 67.98it/s]
Task 4: Test Acc = 77.97%
Memory Buffer updated: 320 samples
Bias correction fine-tuning done.
Task 5: Training on classes [32, 33, 34, 35, 36, 37, 38, 39]
Task 5, Epoch 1: Loss 385.784, Train Acc 52.56%
Task 5, Epoch 2: Loss 299.874, Train Acc 74.95%
Task 5, Epoch 3: Loss 237.962, Train Acc 85.04%
Task 5, Epoch 4: Loss 203.524, Train Acc 86.58%
Task 5, Epoch 5: Loss 190.295, Train Acc 90.58%
100%|█████████████████████████████████████████| 194/194 [00:02<00:00, 67.01it/s]
Task 5: Test Acc = 80.09%
Memory Buffer updated: 400 samples
Bias correction fine-tuning done.


Run history:


epoch▁▃▅▆█▁▃▅▆█▁▃▅▆█▁▃▅▆█▁▃▅▆█
incremental_accuracy██▅▁▂
task▁▁▁▁▁▁▃▃▃▃▃▃▅▅▅▅▅▅▆▆▆▆▆▆██████
train/train_accuracy▁▅▇██▄████▄▆▇▇▇▃▅▆▇▇▃▅▆▇▇
train/train_loss█▅▂▁▁▇▄▂▁▁▂▂▁▁▁▃▃▂▂▂▄▃▂▂▂


Run summary:


epoch4
incremental_accuracy80.08878
task4
train/train_accuracy90.57723
train/train_loss190.29529


View run DER at: https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL/runs/es0akd2d
View project at: https://wandb.ai/kaiyuan_xu_upsud/GTSRB-CIL
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20250322_051655-es0akd2d/logs

Remark 3

  • Upper Bound represents the best possible accuracy (all data used together).

  • DER outperforms Rehearsal_KD, thanks to dynamically expanding the model capacity (alleviating the representation limitation of a fixed backbone).

  • Rehearsal_KD is better than Fine-tuning but still suffers from some forgetting, especially as more tasks are added.

  • Fine-tuning clearly fails in the class-incremental setting, quickly forgetting previous classes.

image.png