Training › Continual Learning¶
Source path: AlphaBrain/training/continual_learning/
The Continual Learning (CL) module: algorithm base class, replay buffer, task sequences, and the training-loop entrypoint.
The top-level __init__.py re-exports CLAlgorithm and ReplayBuffer for backwards compatibility.
Top-level re-exports¶
continual_learning ¶
Continual Learning module.
Sub-packages
- algorithms/ — CL algorithms (ReplayBuffer, …) and their CLAlgorithm base.
- datasets/ — Task sequences and per-task dataset filtering.
Top-level entry
train — Continual training loop (AlphaBrain.training.continual_learning.train.main).
Re-exports for backward compatibility (old import paths still work): ReplayBuffer ← algorithms.replay_buffer.ReplayBuffer
CLAlgorithm ¶
Bases: ABC
Interface every continual-learning algorithm must satisfy.
observe abstractmethod ¶
Called every training step with the current task batch.
Replay-based methods (ER, GEM) use this to grow/refresh memory. Regularization methods (EWC, SI) use it to accumulate importance statistics (Fisher information, path integrals, etc.).
Source code in AlphaBrain/training/continual_learning/algorithms/base.py
sample abstractmethod ¶
Return the algorithm's auxiliary artifact for this step (or None/empty to skip).
Return shape is algorithm-specific:

- ER / GEM: list[dict] — raw samples ready to be collated.
- EWC / SI: dict[str, Tensor] — per-parameter regularization terms.
- LwF: dict[str, Tensor] — teacher logits on the current batch.
The trainer dispatches on algorithm type to combine this with the current-task batch (e.g. mix-in ratio for replay, KL term for LwF).
Source code in AlphaBrain/training/continual_learning/algorithms/base.py
on_task_end abstractmethod ¶
Hook invoked after a task finishes.
Typical uses:

- EWC: snapshot parameters, compute Fisher information on the current task's dataset.
- LwF: snapshot the teacher model weights.
- ER: no-op (the trainer populates the replay buffer directly at task end).
Source code in AlphaBrain/training/continual_learning/algorithms/base.py
state_dict abstractmethod ¶
Return a JSON-serializable snapshot of the algorithm state.
This is written alongside model checkpoints so CL state survives interruption/restart across tasks.
Source code in AlphaBrain/training/continual_learning/algorithms/base.py
load_state_dict abstractmethod ¶
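To make the four hooks concrete, here is a minimal toy algorithm that mirrors the protocol above. This is a hedged sketch: the real abstract methods may take additional arguments, and `CountingAlgorithm` with its step-counting state is purely illustrative.

```python
# Hypothetical minimal algorithm mirroring the CLAlgorithm protocol.
# It tracks how many batches each task contributed -- a stand-in for real CL state.
class CountingAlgorithm:
    def __init__(self):
        self.steps_per_task = {}

    def observe(self, task_id, batch):
        # Called every training step with the current task batch.
        self.steps_per_task[task_id] = self.steps_per_task.get(task_id, 0) + 1

    def sample(self, batch_size):
        # No auxiliary artifact to contribute; return an empty list to skip.
        return []

    def on_task_end(self, task_id):
        pass  # nothing to snapshot in this toy example

    def state_dict(self):
        # JSON-serializable snapshot, written alongside model checkpoints.
        return {"steps_per_task": dict(self.steps_per_task)}

    def load_state_dict(self, state):
        self.steps_per_task = dict(state["steps_per_task"])
```

The trainer only sees these five methods, which is what lets new algorithms plug in without touching the training loop.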
ReplayBuffer ¶
Bases: CLAlgorithm
Experience Replay buffer that stores samples from past tasks.
Uses reservoir sampling to maintain a fixed-size buffer per task. During training, samples from the buffer are mixed with current task data at a configurable ratio.
ER is a task-end-populated algorithm (the buffer is filled once after each task finishes via populate_from_dataset), so the per-step observe hook is a no-op. The trainer calls populate_from_dataset directly in its task-end handler.
Usage:

```python
buffer = ReplayBuffer(buffer_size_per_task=500)

# After finishing task 0:
buffer.populate_from_dataset(task_id=0, dataset=task0_dataset)

# During task 1 training:
replay_samples = buffer.sample(batch_size=4)  # list[dict]
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
buffer_size_per_task | int | Maximum number of samples stored per task. | 500 |
seed | int | Random seed for reproducibility. | 42 |
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
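To make the "configurable ratio" concrete, here is a hedged sketch of batch mixing. `mix_batch` and the plain-dict samples are illustrative assumptions, not the trainer's actual collation logic:

```python
def mix_batch(current_samples, replay_samples, replay_batch_ratio=0.3):
    """Replace a fraction of the current-task batch with replay samples.

    Illustrative only: the real trainer collates framework-specific sample dicts.
    """
    batch_size = len(current_samples)
    # Cap the replay portion by what the buffer can actually provide.
    n_replay = min(int(batch_size * replay_batch_ratio), len(replay_samples))
    # Keep the first (batch_size - n_replay) current samples, append replay.
    return current_samples[: batch_size - n_replay] + replay_samples[:n_replay]

mixed = mix_batch([{"task": 1}] * 8, [{"task": 0}] * 4, replay_batch_ratio=0.25)
# 8-sample batch: 6 current-task samples plus 2 replay samples
```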
populate_from_dataset ¶
Store samples from a dataset into the buffer using reservoir sampling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
task_id | int | Identifier for the task. | required |
dataset | Dataset | Dataset to sample from (must support `__len__` and `__getitem__`). | required |
num_samples | Optional[int] | Number of samples to store. Defaults to buffer_size_per_task. | None |
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
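Reservoir sampling keeps a uniform random subset of a stream in fixed memory. A minimal sketch, assuming only the documented `__len__`/`__getitem__` dataset contract; the function name and seeding are illustrative, not the module's actual implementation:

```python
import random

def reservoir_sample(dataset, k, seed=42):
    """Return k items drawn uniformly from dataset using one pass (Algorithm R)."""
    rng = random.Random(seed)
    reservoir = []
    for i in range(len(dataset)):
        if i < k:
            # Fill the reservoir with the first k items.
            reservoir.append(dataset[i])
        else:
            # Each item ends up in the reservoir with probability k / (i + 1).
            j = rng.randint(0, i)
            if j < k:
                reservoir[j] = dataset[i]
    return reservoir
```

The fixed seed makes the stored subset reproducible across runs, which matters when the buffer is re-populated after a restart.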
sample ¶
Sample a batch uniformly from all stored tasks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch_size | int | Number of samples to return. | required |
Returns:
| Type | Description |
|---|---|
List[dict] | List of sample dicts. Empty list if buffer is empty. |
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
sample_balanced ¶
Sample a batch with equal representation from each stored task.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch_size | int | Number of samples to return. | required |
Returns:
| Type | Description |
|---|---|
List[dict] | List of sample dicts. Empty list if buffer is empty. |
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
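The difference between the two sampling strategies can be shown over a per-task store shaped like `{task_id: [sample, ...]}`. These helper names and the store layout are illustrative, not the class internals:

```python
import random

def sample_uniform(store, batch_size, rng):
    # Pool every stored sample; tasks with more samples are drawn more often.
    pool = [s for samples in store.values() for s in samples]
    return rng.sample(pool, min(batch_size, len(pool)))

def sample_balanced(store, batch_size, rng):
    # Draw (roughly) the same number of samples from each stored task.
    tasks = list(store)
    per_task = max(1, batch_size // len(tasks))
    out = []
    for t in tasks:
        out.extend(rng.sample(store[t], min(per_task, len(store[t]))))
    return out[:batch_size]
```

Balanced sampling protects small or early tasks from being drowned out when later tasks contribute many more samples.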
get_task_ids ¶
get_task_size ¶
clear ¶
state_dict ¶
Return serializable state for checkpointing.
Note: only metadata is serialized — the actual sample tensors are not saved (they can be large). On resume, callers must re-populate the buffer by iterating each task's dataset again.
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
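The note above implies a two-step resume flow: restore metadata from JSON, then rebuild the memory from the task datasets. A hedged sketch; `completed_task_datasets`, the helper names, and the file layout are assumptions for illustration:

```python
import json

def save_cl_state(buffer, path):
    """Write the algorithm's JSON-serializable snapshot next to the checkpoint."""
    with open(path, "w") as f:
        json.dump(buffer.state_dict(), f)

def resume_cl_state(buffer, path, completed_task_datasets):
    """Restore metadata, then re-populate samples (metadata alone holds no tensors)."""
    with open(path) as f:
        buffer.load_state_dict(json.load(f))
    # Rebuild memory for every task that had already finished before the restart.
    for task_id, dataset in completed_task_datasets.items():
        buffer.populate_from_dataset(task_id=task_id, dataset=dataset)
```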
observe ¶
No-op: ER populates from the full dataset at task-end, not per step.
See populate_from_dataset for the actual memory update, which the trainer invokes from its task-end handler.
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
on_task_end ¶
No-op: the trainer calls populate_from_dataset directly.
(ER needs the full task dataset object, which a generic no-arg hook cannot provide — so the trainer orchestrates the population explicitly.)
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
load_state_dict ¶
Restore metadata from a snapshot produced by state_dict.
Only hyperparameters are restored (buffer size, seed). Actual samples must be re-populated from the task datasets on resume.
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
Training entrypoint¶
train ¶
Continual Learning Trainer for AlphaBrain.
Trains a VLA model sequentially on a stream of tasks, with optional Experience Replay to mitigate catastrophic forgetting.
Design:

- Follows the framework convention of one trainer file per training strategy.
- Reuses existing build_framework, build_dataloader, and TrainerUtils.
- Adds an outer loop over tasks and integrates the replay buffer.
Config
Add a continual_learning section to your YAML:
```yaml
continual_learning:
  task_sequence: libero_spatial      # CL sequence name (see continual_learning.py)
  steps_per_task: 10000              # training steps per task
  save_checkpoint_per_task: true     # save after each task
  replay:
    enabled: true
    method: experience_replay        # replay method (currently only ER)
    buffer_size_per_task: 500        # samples to store per past task
    replay_batch_ratio: 0.3          # fraction of each batch from replay
    balanced_sampling: false         # equal samples per task vs. uniform
```
ContinualVLATrainer ¶
Bases: TrainerUtils
Sequential task trainer with experience replay support.
Outer loop: iterate over tasks in the CL sequence. Inner loop: standard VLA training on the current task + replay samples.
Source code in AlphaBrain/training/continual_learning/train.py
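The two loops can be sketched schematically. Everything here is an illustrative assumption, not the trainer's real API: `task_datasets` as plain sample lists, `train_step` as a callback, and the fixed-size slice standing in for real dataloader batching:

```python
def continual_train(task_datasets, buffer, steps_per_task, replay_ratio, train_step):
    """Outer loop over tasks; inner loop trains on current-task plus replay samples."""
    for task_id, dataset in task_datasets.items():
        for _ in range(steps_per_task):
            batch = dataset[:4]  # stand-in for real dataloader batching
            # Mix in replay samples at the configured ratio (empty before task 0 ends).
            replay = buffer.sample(batch_size=int(len(batch) * replay_ratio))
            train_step(task_id, batch + replay)
        # ER is populated once per finished task, not per step.
        buffer.populate_from_dataset(task_id=task_id, dataset=dataset)
```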
prepare_training ¶
Initialize training state (checkpoints, freezing, distributed setup).
Source code in AlphaBrain/training/continual_learning/train.py
train ¶
Execute the continual learning training loop.
Source code in AlphaBrain/training/continual_learning/train.py
build_full_dataset ¶
Build the full (unfiltered) VLA dataset from config.
build_task_dataloader ¶
Build a DataLoader for a specific task by filtering the full dataset.
Source code in AlphaBrain/training/continual_learning/train.py
Algorithms¶
Base class¶
base ¶
Abstract base class for continual-learning algorithms.
All CL algorithms (Experience Replay / EWC / LwF / SI / GEM / ...) implement this interface. The continual trainer (AlphaBrain.training.continual_learning.train) only talks to the algorithm through this protocol, so new methods can be plugged in without touching the training loop.
Current implementations¶
ReplayBuffer (algorithms.replay_buffer): Reservoir-sampled experience replay with uniform and balanced sampling strategies.
Planned implementations¶
- EWC: Elastic Weight Consolidation (Kirkpatrick et al. 2017)
- LwF: Learning without Forgetting (Li & Hoiem 2017)
- SI: Synaptic Intelligence (Zenke et al. 2017)
- GEM: Gradient Episodic Memory (Lopez-Paz & Ranzato 2017)
CLAlgorithm ¶
Bases: ABC
Interface every continual-learning algorithm must satisfy.
observe abstractmethod ¶
Called every training step with the current task batch.
Replay-based methods (ER, GEM) use this to grow/refresh memory. Regularization methods (EWC, SI) use it to accumulate importance statistics (Fisher information, path integrals, etc.).
Source code in AlphaBrain/training/continual_learning/algorithms/base.py
sample abstractmethod ¶
Return the algorithm's auxiliary artifact for this step (or None/empty to skip).
Return shape is algorithm-specific:

- ER / GEM: list[dict] — raw samples ready to be collated.
- EWC / SI: dict[str, Tensor] — per-parameter regularization terms.
- LwF: dict[str, Tensor] — teacher logits on the current batch.
The trainer dispatches on algorithm type to combine this with the current-task batch (e.g. mix-in ratio for replay, KL term for LwF).
Source code in AlphaBrain/training/continual_learning/algorithms/base.py
on_task_end abstractmethod ¶
Hook invoked after a task finishes.
Typical uses:

- EWC: snapshot parameters, compute Fisher information on the current task's dataset.
- LwF: snapshot the teacher model weights.
- ER: no-op (the trainer populates the replay buffer directly at task end).
Source code in AlphaBrain/training/continual_learning/algorithms/base.py
state_dict abstractmethod ¶
Return a JSON-serializable snapshot of the algorithm state.
This is written alongside model checkpoints so CL state survives interruption/restart across tasks.
Source code in AlphaBrain/training/continual_learning/algorithms/base.py
load_state_dict abstractmethod ¶
Replay buffer¶
replay_buffer ¶
replay_buffer.py
Experience Replay buffer for continual learning. Stores samples from previously learned tasks and provides mixed batches to mitigate catastrophic forgetting.
Supports:

- Reservoir sampling for memory-efficient storage
- Per-task buffer management
- Configurable replay ratio for batch mixing
- Conforms to the CLAlgorithm interface
ReplayBuffer ¶
Bases: CLAlgorithm
Experience Replay buffer that stores samples from past tasks.
Uses reservoir sampling to maintain a fixed-size buffer per task. During training, samples from the buffer are mixed with current task data at a configurable ratio.
ER is a task-end-populated algorithm (the buffer is filled once after each task finishes via populate_from_dataset), so the per-step observe hook is a no-op. The trainer calls populate_from_dataset directly in its task-end handler.
Usage:

```python
buffer = ReplayBuffer(buffer_size_per_task=500)

# After finishing task 0:
buffer.populate_from_dataset(task_id=0, dataset=task0_dataset)

# During task 1 training:
replay_samples = buffer.sample(batch_size=4)  # list[dict]
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
buffer_size_per_task | int | Maximum number of samples stored per task. | 500 |
seed | int | Random seed for reproducibility. | 42 |
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
populate_from_dataset ¶
Store samples from a dataset into the buffer using reservoir sampling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
task_id | int | Identifier for the task. | required |
dataset | Dataset | Dataset to sample from (must support `__len__` and `__getitem__`). | required |
num_samples | Optional[int] | Number of samples to store. Defaults to buffer_size_per_task. | None |
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
sample ¶
Sample a batch uniformly from all stored tasks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch_size | int | Number of samples to return. | required |
Returns:
| Type | Description |
|---|---|
List[dict] | List of sample dicts. Empty list if buffer is empty. |
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
sample_balanced ¶
Sample a batch with equal representation from each stored task.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch_size | int | Number of samples to return. | required |
Returns:
| Type | Description |
|---|---|
List[dict] | List of sample dicts. Empty list if buffer is empty. |
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
get_task_ids ¶
get_task_size ¶
clear ¶
state_dict ¶
Return serializable state for checkpointing.
Note: only metadata is serialized — the actual sample tensors are not saved (they can be large). On resume, callers must re-populate the buffer by iterating each task's dataset again.
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
observe ¶
No-op: ER populates from the full dataset at task-end, not per step.
See populate_from_dataset for the actual memory update, which the trainer invokes from its task-end handler.
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
on_task_end ¶
No-op: the trainer calls populate_from_dataset directly.
(ER needs the full task dataset object, which a generic no-arg hook cannot provide — so the trainer orchestrates the population explicitly.)
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
load_state_dict ¶
Restore metadata from a snapshot produced by state_dict.
Only hyperparameters are restored (buffer size, seed). Actual samples must be re-populated from the task datasets on resume.
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
Subpackage exports¶
algorithms ¶
Continual-learning algorithms.
See AlphaBrain.training.continual_learning.algorithms.base.CLAlgorithm for the interface every algorithm implements.
CLAlgorithm ¶
Bases: ABC
Interface every continual-learning algorithm must satisfy.
observe abstractmethod ¶
Called every training step with the current task batch.
Replay-based methods (ER, GEM) use this to grow/refresh memory. Regularization methods (EWC, SI) use it to accumulate importance statistics (Fisher information, path integrals, etc.).
Source code in AlphaBrain/training/continual_learning/algorithms/base.py
sample abstractmethod ¶
Return the algorithm's auxiliary artifact for this step (or None/empty to skip).
Return shape is algorithm-specific:

- ER / GEM: list[dict] — raw samples ready to be collated.
- EWC / SI: dict[str, Tensor] — per-parameter regularization terms.
- LwF: dict[str, Tensor] — teacher logits on the current batch.
The trainer dispatches on algorithm type to combine this with the current-task batch (e.g. mix-in ratio for replay, KL term for LwF).
Source code in AlphaBrain/training/continual_learning/algorithms/base.py
on_task_end abstractmethod ¶
Hook invoked after a task finishes.
Typical uses:

- EWC: snapshot parameters, compute Fisher information on the current task's dataset.
- LwF: snapshot the teacher model weights.
- ER: no-op (the trainer populates the replay buffer directly at task end).
Source code in AlphaBrain/training/continual_learning/algorithms/base.py
state_dict abstractmethod ¶
Return a JSON-serializable snapshot of the algorithm state.
This is written alongside model checkpoints so CL state survives interruption/restart across tasks.
Source code in AlphaBrain/training/continual_learning/algorithms/base.py
load_state_dict abstractmethod ¶
ReplayBuffer ¶
Bases: CLAlgorithm
Experience Replay buffer that stores samples from past tasks.
Uses reservoir sampling to maintain a fixed-size buffer per task. During training, samples from the buffer are mixed with current task data at a configurable ratio.
ER is a task-end-populated algorithm (the buffer is filled once after each task finishes via populate_from_dataset), so the per-step observe hook is a no-op. The trainer calls populate_from_dataset directly in its task-end handler.
Usage:

```python
buffer = ReplayBuffer(buffer_size_per_task=500)

# After finishing task 0:
buffer.populate_from_dataset(task_id=0, dataset=task0_dataset)

# During task 1 training:
replay_samples = buffer.sample(batch_size=4)  # list[dict]
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
buffer_size_per_task | int | Maximum number of samples stored per task. | 500 |
seed | int | Random seed for reproducibility. | 42 |
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
populate_from_dataset ¶
Store samples from a dataset into the buffer using reservoir sampling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
task_id | int | Identifier for the task. | required |
dataset | Dataset | Dataset to sample from (must support `__len__` and `__getitem__`). | required |
num_samples | Optional[int] | Number of samples to store. Defaults to buffer_size_per_task. | None |
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
sample ¶
Sample a batch uniformly from all stored tasks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch_size | int | Number of samples to return. | required |
Returns:
| Type | Description |
|---|---|
List[dict] | List of sample dicts. Empty list if buffer is empty. |
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
sample_balanced ¶
Sample a batch with equal representation from each stored task.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch_size | int | Number of samples to return. | required |
Returns:
| Type | Description |
|---|---|
List[dict] | List of sample dicts. Empty list if buffer is empty. |
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
get_task_ids ¶
get_task_size ¶
clear ¶
state_dict ¶
Return serializable state for checkpointing.
Note: only metadata is serialized — the actual sample tensors are not saved (they can be large). On resume, callers must re-populate the buffer by iterating each task's dataset again.
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
observe ¶
No-op: ER populates from the full dataset at task-end, not per step.
See populate_from_dataset for the actual memory update, which the trainer invokes from its task-end handler.
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
on_task_end ¶
No-op: the trainer calls populate_from_dataset directly.
(ER needs the full task dataset object, which a generic no-arg hook cannot provide — so the trainer orchestrates the population explicitly.)
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
load_state_dict ¶
Restore metadata from a snapshot produced by state_dict.
Only hyperparameters are restored (buffer size, seed). Actual samples must be re-populated from the task datasets on resume.
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
Datasets / task sequences¶
task_sequences ¶
continual_learning.py
Defines continual learning task sequences for sequential task training. Each sequence specifies a base data_mix and task ordering. Provides utilities to filter datasets by task_index for per-task training.
TaskFilteredDataset ¶
Bases: Dataset
Wraps a LeRobotMixtureDataset to only expose steps from specific task indices.
This is a lightweight wrapper that filters the base dataset's step sampling without copying data or modifying the underlying dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
base_dataset | | A LeRobotMixtureDataset (or LeRobotSingleDataset). | required |
task_indices | List[int] | List of task_index values to include. | required |
episode_task_map | Dict[int, List[int]] | Mapping from task_index -> list of episode_ids. | required |
Source code in AlphaBrain/training/continual_learning/datasets/task_sequences.py
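The filtering idea behind this wrapper can be sketched with a simplified stand-in. `FilteredView` mirrors the documented parameters but is not the real class, which wraps LeRobot datasets:

```python
class FilteredView:
    """Expose only episodes whose task_index is allowed, without copying data."""

    def __init__(self, base, task_indices, episode_task_map):
        self.base = base
        # Flatten the allowed tasks' episode ids into one lookup list.
        self.episode_ids = [
            ep for t in task_indices for ep in episode_task_map.get(t, [])
        ]

    def __len__(self):
        return len(self.episode_ids)

    def __getitem__(self, i):
        # Indirect through the filtered id list into the untouched base dataset.
        return self.base[self.episode_ids[i]]
```

Because only an index list is stored, switching tasks between training phases is cheap: build a new view over the same base dataset.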
save_dataset_statistics ¶
Delegate to base dataset.
get_task_sequence ¶
Retrieve a CL task sequence by name.
Source code in AlphaBrain/training/continual_learning/datasets/task_sequences.py
build_episode_task_map ¶
Build mapping from task_index to list of episode_ids by reading episode data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dataset | | A LeRobotSingleDataset instance. | required |
Returns:
| Type | Description |
|---|---|
Dict[int, List[int]] | Dict mapping task_index -> list of trajectory_ids (episode indices). |
Source code in AlphaBrain/training/continual_learning/datasets/task_sequences.py
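Under the assumption that each episode record exposes a task_index field (the real function reads LeRobot episode metadata, so the access pattern differs), the mapping can be sketched as:

```python
from collections import defaultdict

def build_task_map(episodes):
    """Map task_index -> list of episode ids, preserving episode order."""
    task_map = defaultdict(list)
    for episode_id, episode in enumerate(episodes):
        # Group each episode id under the task it belongs to.
        task_map[episode["task_index"]].append(episode_id)
    return dict(task_map)
```

This map is what TaskFilteredDataset consumes to restrict sampling to a chosen set of tasks.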
datasets ¶
Continual-learning data primitives: task sequences + per-task filtering.
TaskFilteredDataset ¶
Bases: Dataset
Wraps a LeRobotMixtureDataset to only expose steps from specific task indices.
This is a lightweight wrapper that filters the base dataset's step sampling without copying data or modifying the underlying dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
base_dataset | | A LeRobotMixtureDataset (or LeRobotSingleDataset). | required |
task_indices | List[int] | List of task_index values to include. | required |
episode_task_map | Dict[int, List[int]] | Mapping from task_index -> list of episode_ids. | required |
Source code in AlphaBrain/training/continual_learning/datasets/task_sequences.py
save_dataset_statistics ¶
Delegate to base dataset.
build_episode_task_map ¶
Build mapping from task_index to list of episode_ids by reading episode data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dataset | | A LeRobotSingleDataset instance. | required |
Returns:
| Type | Description |
|---|---|
Dict[int, List[int]] | Dict mapping task_index -> list of trajectory_ids (episode indices). |
Source code in AlphaBrain/training/continual_learning/datasets/task_sequences.py
get_task_sequence ¶
Retrieve a CL task sequence by name.