Training › Continual Learning

Source path: AlphaBrain/training/continual_learning/

The Continual Learning (CL) module: algorithm base class, replay buffer, task sequences, and the training-loop entrypoint.

The top-level __init__.py re-exports CLAlgorithm and ReplayBuffer for backwards compatibility.


Top-level re-exports

continual_learning

Continual Learning module.

Sub-packages

  • algorithms/: CL algorithms (ReplayBuffer, …) and their CLAlgorithm base.
  • datasets/: Task sequences and per-task dataset filtering.

Top-level entry

train — Continual training loop (AlphaBrain.training.continual_learning.train.main).

Re-exports for backward compatibility (old import paths still work): ReplayBuffer → algorithms.replay_buffer.ReplayBuffer

CLAlgorithm

Bases: ABC

Interface every continual-learning algorithm must satisfy.

name property
name: str

Short identifier used in logs / checkpoints (default = class name).

observe abstractmethod
observe(batch: dict, task_id: int) -> None

Called every training step with the current task batch.

Replay-based methods (ER, GEM) use this to grow/refresh memory. Regularization methods (EWC, SI) use it to accumulate importance statistics (Fisher information, path integrals, etc.).

Source code in AlphaBrain/training/continual_learning/algorithms/base.py
@abstractmethod
def observe(self, batch: dict, task_id: int) -> None:
    """Called every training step with the **current task** batch.

    Replay-based methods (ER, GEM) use this to grow/refresh memory.
    Regularization methods (EWC, SI) use it to accumulate importance
    statistics (Fisher information, path integrals, etc.).
    """
sample abstractmethod
sample(batch_size: int) -> Any

Return the algorithm's auxiliary artifact for this step (or None/empty to skip).

Return shape is algorithm-specific:
  • ER / GEM: list[dict], raw samples ready to be collated.
  • EWC / SI: dict[str, Tensor], per-parameter regularization terms.
  • LwF: dict[str, Tensor], teacher logits on the current batch.

The trainer dispatches on algorithm type to combine this with the current-task batch (e.g. mix-in ratio for replay, KL term for LwF).

Source code in AlphaBrain/training/continual_learning/algorithms/base.py
@abstractmethod
def sample(self, batch_size: int) -> Any:
    """Return the algorithm's auxiliary artifact for this step (or None/empty to skip).

    Return shape is algorithm-specific:
      * ER / GEM : `list[dict]`       — raw samples ready to be collated.
      * EWC / SI : `dict[str, Tensor]`— per-parameter regularization terms.
      * LwF      : `dict[str, Tensor]`— teacher logits on the current batch.

    The trainer dispatches on algorithm type to combine this with the
    current-task batch (e.g. mix-in ratio for replay, KL term for LwF).
    """
on_task_end abstractmethod
on_task_end(task_id: int) -> None

Hook invoked after a task finishes.

Typical uses:
  • EWC: snapshot parameters, compute Fisher on the current task's dataset.
  • LwF: snapshot the teacher model weights.
  • ER: no-op (ReplayBuffer is filled by the trainer through populate_from_dataset rather than by this hook).

Source code in AlphaBrain/training/continual_learning/algorithms/base.py
@abstractmethod
def on_task_end(self, task_id: int) -> None:
    """Hook invoked after a task finishes.

    Typical uses:
    * EWC: snapshot parameters, compute Fisher on current task's dataset.
    * LwF: snapshot the teacher model weights.
    * ER : no-op (reservoir sampling happens online).
    """
state_dict abstractmethod
state_dict() -> dict[str, Any]

Return a JSON-serializable snapshot of the algorithm state.

This is written alongside model checkpoints so CL state survives interruption/restart across tasks.

Source code in AlphaBrain/training/continual_learning/algorithms/base.py
@abstractmethod
def state_dict(self) -> dict[str, Any]:
    """Return a JSON-serializable snapshot of the algorithm state.

    This is written alongside model checkpoints so CL state survives
    interruption/restart across tasks.
    """
load_state_dict abstractmethod
load_state_dict(state: dict[str, Any]) -> None

Restore algorithm state from a dict produced by state_dict().

Source code in AlphaBrain/training/continual_learning/algorithms/base.py
@abstractmethod
def load_state_dict(self, state: dict[str, Any]) -> None:
    """Restore algorithm state from a dict produced by `state_dict()`."""
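To make the five-method contract concrete, here is a minimal sketch of an algorithm that satisfies it. `CLAlgorithmSketch` is a stand-in that mirrors the abstract methods listed above so the example runs without AlphaBrain installed; `RecentBatchMemory` and its `capacity` parameter are hypothetical and not part of the library.

```python
import random
from abc import ABC, abstractmethod
from collections import deque
from typing import Any


class CLAlgorithmSketch(ABC):
    """Stand-in mirroring the documented CLAlgorithm interface (assumption)."""

    @property
    def name(self) -> str:
        return type(self).__name__

    @abstractmethod
    def observe(self, batch: dict, task_id: int) -> None: ...

    @abstractmethod
    def sample(self, batch_size: int) -> Any: ...

    @abstractmethod
    def on_task_end(self, task_id: int) -> None: ...

    @abstractmethod
    def state_dict(self) -> dict[str, Any]: ...

    @abstractmethod
    def load_state_dict(self, state: dict[str, Any]) -> None: ...


class RecentBatchMemory(CLAlgorithmSketch):
    """Hypothetical algorithm: remember the most recent `capacity` batches."""

    def __init__(self, capacity: int = 8, seed: int = 0):
        self.capacity = capacity
        self.seed = seed
        self.rng = random.Random(seed)
        self._memory: deque = deque(maxlen=capacity)
        self.finished_tasks: list[int] = []

    def observe(self, batch: dict, task_id: int) -> None:
        # Called once per training step; the deque evicts the oldest entry.
        self._memory.append({"task_id": task_id, **batch})

    def sample(self, batch_size: int) -> list[dict]:
        if not self._memory:
            return []  # an empty result tells the trainer to skip replay
        return self.rng.choices(list(self._memory), k=batch_size)

    def on_task_end(self, task_id: int) -> None:
        self.finished_tasks.append(task_id)

    def state_dict(self) -> dict[str, Any]:
        # Metadata only, so the result stays JSON-serializable.
        return {"algorithm": self.name, "capacity": self.capacity,
                "seed": self.seed, "finished_tasks": list(self.finished_tasks)}

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self.capacity = state.get("capacity", self.capacity)
        self.seed = state.get("seed", self.seed)
        self.rng = random.Random(self.seed)
        self._memory = deque(maxlen=self.capacity)
        self.finished_tasks = list(state.get("finished_tasks", []))
```

Like ReplayBuffer, this sketch serializes metadata only, so the memory itself starts empty after load_state_dict.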

ReplayBuffer

ReplayBuffer(buffer_size_per_task: int = 500, seed: int = 42)

Bases: CLAlgorithm

Experience Replay buffer that stores samples from past tasks.

Keeps a fixed-size buffer per task, holding a uniform random subset of that task's dataset (the docstrings call this reservoir sampling; the implementation shuffles all indices and keeps the first k). During training, samples from the buffer are mixed with current-task data at a configurable ratio.

ER is populated at task end: the buffer is filled once after each task finishes via populate_from_dataset, so the per-step observe hook is a no-op. The trainer calls populate_from_dataset directly in its task-end handler.

Usage

buffer = ReplayBuffer(buffer_size_per_task=500)

# After finishing task 0:
buffer.populate_from_dataset(task_id=0, dataset=task0_dataset)

# During task 1 training:
replay_samples = buffer.sample(batch_size=4)  # list[dict]

Parameters:

Name | Type | Description | Default
buffer_size_per_task | int | Maximum number of samples stored per task. | 500
seed | int | Random seed for reproducibility. | 42
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def __init__(self, buffer_size_per_task: int = 500, seed: int = 42):
    """
    Args:
        buffer_size_per_task: Maximum number of samples stored per task.
        seed: Random seed for reproducibility.
    """
    self.buffer_size_per_task = buffer_size_per_task
    self.seed = seed
    self.rng = random.Random(seed)

    # task_id -> list of stored samples
    self._buffers: Dict[int, List[dict]] = {}
    # Track total samples across all tasks
    self._total_samples = 0
num_tasks property
num_tasks: int

Number of tasks stored in the buffer.

total_samples property
total_samples: int

Total number of samples across all tasks.

populate_from_dataset
populate_from_dataset(task_id: int, dataset: Dataset, num_samples: Optional[int] = None)

Store samples from a dataset into the buffer using reservoir sampling.

Parameters:

Name | Type | Description | Default
task_id | int | Identifier for the task. | required
dataset | Dataset | Dataset to sample from (must support __len__ and __getitem__). | required
num_samples | Optional[int] | Number of samples to store. Defaults to buffer_size_per_task. | None
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def populate_from_dataset(self, task_id: int, dataset: Dataset, num_samples: Optional[int] = None):
    """Store samples from a dataset into the buffer using reservoir sampling.

    Args:
        task_id: Identifier for the task.
        dataset: Dataset to sample from (must support __len__ and __getitem__).
        num_samples: Number of samples to store. Defaults to buffer_size_per_task.
    """
    if num_samples is None:
        num_samples = self.buffer_size_per_task

    n = len(dataset)
    k = min(num_samples, n)

    # Reservoir sampling: select k items from n uniformly at random
    indices = list(range(n))
    self.rng.shuffle(indices)
    selected_indices = sorted(indices[:k])

    samples = []
    for idx in selected_indices:
        try:
            sample = dataset[idx]
            samples.append(sample)
        except Exception as e:
            logger.warning(f"Failed to read sample {idx} for task {task_id}: {e}")
            continue

    # Update buffer
    if task_id in self._buffers:
        self._total_samples -= len(self._buffers[task_id])
    self._buffers[task_id] = samples
    self._total_samples += len(samples)

    logger.info(
        f"Replay buffer: stored {len(samples)} samples for task {task_id} "
        f"(total: {self._total_samples} across {self.num_tasks} tasks)"
    )
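The implementation above shuffles every index and keeps the first k, which requires len(dataset) up front. For comparison, here is a minimal sketch of classic reservoir sampling (Algorithm R), which draws a uniform k-subset from a stream of unknown length in a single pass. `reservoir_sample` is a hypothetical helper and not part of ReplayBuffer.

```python
import random
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def reservoir_sample(stream: Iterable[T], k: int, rng: random.Random) -> List[T]:
    """Return up to k items drawn uniformly without replacement from stream."""
    reservoir: List[T] = []
    for i, item in enumerate(stream):
        if i < k:
            # Fill phase: the first k items are kept unconditionally.
            reservoir.append(item)
        else:
            # Item i replaces a random slot with probability k / (i + 1).
            j = rng.randint(0, i)  # inclusive on both ends
            if j < k:
                reservoir[j] = item
    return reservoir


picked = reservoir_sample(range(10_000), k=5, rng=random.Random(42))
print(picked)
```

Both approaches give every item the same chance of being kept; the streaming variant only matters when the dataset cannot be indexed or its length is unknown.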
sample
sample(batch_size: int) -> List[dict]

Sample a batch uniformly from all stored tasks.

Parameters:

Name | Type | Description | Default
batch_size | int | Number of samples to return. | required

Returns:

Type | Description
List[dict] | List of sample dicts. Empty list if buffer is empty.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def sample(self, batch_size: int) -> List[dict]:
    """Sample a batch uniformly from all stored tasks.

    Args:
        batch_size: Number of samples to return.

    Returns:
        List of sample dicts. Empty list if buffer is empty.
    """
    if self.is_empty():
        return []

    # Collect all samples across tasks
    all_samples = []
    for task_samples in self._buffers.values():
        all_samples.extend(task_samples)

    # Sample without replacement when possible to maximize diversity
    if batch_size <= len(all_samples):
        return self.rng.sample(all_samples, k=batch_size)
    else:
        return self.rng.choices(all_samples, k=batch_size)
sample_balanced
sample_balanced(batch_size: int) -> List[dict]

Sample a batch with equal representation from each stored task.

Parameters:

Name | Type | Description | Default
batch_size | int | Number of samples to return. | required

Returns:

Type | Description
List[dict] | List of sample dicts. Empty list if buffer is empty.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def sample_balanced(self, batch_size: int) -> List[dict]:
    """Sample a batch with equal representation from each stored task.

    Args:
        batch_size: Number of samples to return.

    Returns:
        List of sample dicts. Empty list if buffer is empty.
    """
    if self.is_empty():
        return []

    samples_per_task = max(1, batch_size // self.num_tasks)
    result = []

    for task_samples in self._buffers.values():
        k = min(samples_per_task, len(task_samples))
        # k <= len(task_samples), so draw without replacement to avoid duplicates
        result.extend(self.rng.sample(task_samples, k=k))

    # If we need more samples to reach batch_size, sample randomly
    while len(result) < batch_size:
        task_id = self.rng.choice(list(self._buffers.keys()))
        result.append(self.rng.choice(self._buffers[task_id]))

    return result[:batch_size]
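The per-task arithmetic in sample_balanced is easier to follow with numbers. The sketch below reproduces only the counting logic (no sampling); `balanced_quota` is a hypothetical helper written for this illustration.

```python
from typing import Dict, Tuple


def balanced_quota(batch_size: int, task_sizes: Dict[int, int]) -> Tuple[Dict[int, int], int]:
    """Mirror the per-task arithmetic of sample_balanced.

    Returns (per-task draw counts, number of random top-up draws).
    """
    num_tasks = len(task_sizes)
    samples_per_task = max(1, batch_size // num_tasks)
    quota = {tid: min(samples_per_task, size) for tid, size in task_sizes.items()}
    top_up = max(0, batch_size - sum(quota.values()))
    return quota, top_up


quota, top_up = balanced_quota(batch_size=10, task_sizes={0: 500, 1: 500, 2: 2})
print(quota, top_up)  # {0: 3, 1: 3, 2: 2} 2
```

When there are more stored tasks than batch_size, every task gets a quota of 1 and the final result[:batch_size] slice drops the surplus, so the tasks stored last contribute nothing to that batch.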
get_task_ids
get_task_ids() -> List[int]

Return list of task IDs stored in the buffer.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def get_task_ids(self) -> List[int]:
    """Return list of task IDs stored in the buffer."""
    return sorted(self._buffers.keys())
get_task_size
get_task_size(task_id: int) -> int

Return number of samples stored for a specific task.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def get_task_size(self, task_id: int) -> int:
    """Return number of samples stored for a specific task."""
    return len(self._buffers.get(task_id, []))
clear
clear()

Clear all stored samples.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def clear(self):
    """Clear all stored samples."""
    self._buffers.clear()
    self._total_samples = 0
state_dict
state_dict() -> dict

Return serializable state for checkpointing.

Note: only metadata is serialized — the actual sample tensors are not saved (they can be large). On resume, callers must re-populate the buffer by iterating each task's dataset again.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def state_dict(self) -> dict:
    """Return serializable state for checkpointing.

    Note: only metadata is serialized — the actual sample tensors are not
    saved (they can be large).  On resume, callers must re-populate the
    buffer by iterating each task's dataset again.
    """
    return {
        "algorithm": self.name,
        "buffer_size_per_task": self.buffer_size_per_task,
        "seed": self.seed,
        "num_tasks": self.num_tasks,
        "total_samples": self._total_samples,
        "task_sizes": {k: len(v) for k, v in self._buffers.items()},
    }
observe
observe(batch: dict, task_id: int) -> None

No-op: ER populates from the full dataset at task-end, not per step.

See populate_from_dataset for the actual memory update, which the trainer invokes from its task-end handler.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def observe(self, batch: dict, task_id: int) -> None:
    """No-op: ER populates from the full dataset at task-end, not per step.

    See :meth:`populate_from_dataset` for the actual memory update, which
    the trainer invokes from its task-end handler.
    """
    return None
on_task_end
on_task_end(task_id: int) -> None

No-op: the trainer calls populate_from_dataset directly.

(ER needs the full task dataset object, which a generic no-arg hook cannot provide — so the trainer orchestrates the population explicitly.)

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def on_task_end(self, task_id: int) -> None:
    """No-op: the trainer calls :meth:`populate_from_dataset` directly.

    (ER needs the full task dataset object, which a generic no-arg hook
    cannot provide — so the trainer orchestrates the population explicitly.)
    """
    return None
load_state_dict
load_state_dict(state: Dict[str, Any]) -> None

Restore metadata from a snapshot produced by state_dict.

Only hyperparameters are restored (buffer size, seed). Actual samples must be re-populated from the task datasets on resume.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def load_state_dict(self, state: Dict[str, Any]) -> None:
    """Restore metadata from a snapshot produced by :meth:`state_dict`.

    Only hyperparameters are restored (buffer size, seed).  Actual samples
    must be re-populated from the task datasets on resume.
    """
    self.buffer_size_per_task = state.get(
        "buffer_size_per_task", self.buffer_size_per_task
    )
    self.seed = state.get("seed", self.seed)
    self.rng = random.Random(self.seed)
    self._buffers = {}
    self._total_samples = 0
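Because only metadata is checkpointed, a resume has two parts: restore the hyperparameters, then repopulate the samples from the task datasets. The sketch below covers the first part as a JSON round trip of a dict shaped like the state_dict output; the values and the file name `cl_state.json` are illustrative assumptions.

```python
import json
import tempfile
from pathlib import Path

# A dict shaped like the ReplayBuffer.state_dict() output shown above
# (values are illustrative).
state = {
    "algorithm": "ReplayBuffer",
    "buffer_size_per_task": 500,
    "seed": 42,
    "num_tasks": 2,
    "total_samples": 1000,
    "task_sizes": {0: 500, 1: 500},
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "cl_state.json"  # hypothetical file name
    path.write_text(json.dumps(state))
    restored = json.loads(path.read_text())

# JSON object keys are always strings, so integer task ids come back as "0", "1".
task_sizes = {int(k): v for k, v in restored["task_sizes"].items()}
print(task_sizes)  # {0: 500, 1: 500}
```

The key conversion matters only if a caller reads task_sizes back; load_state_dict as shown ignores that field.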

Training entrypoint

train

Continual Learning Trainer for AlphaBrain.

Trains a VLA model sequentially on a stream of tasks, with optional Experience Replay to mitigate catastrophic forgetting.

Design:
  • Follows the framework convention of one trainer file per training strategy.
  • Reuses existing build_framework, build_dataloader, and TrainerUtils.
  • Adds an outer loop over tasks and integrates the replay buffer.

Config

Add a continual_learning section to your YAML:

continual_learning:
  task_sequence: libero_spatial      # CL sequence name (see continual_learning.py)
  steps_per_task: 10000              # training steps per task
  save_checkpoint_per_task: true     # save after each task

  replay:
    enabled: true
    method: experience_replay        # replay method (currently only ER)
    buffer_size_per_task: 500        # samples to store per past task
    replay_batch_ratio: 0.3          # fraction of each batch from replay
    balanced_sampling: false         # equal samples per task vs. uniform
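The config describes replay_batch_ratio as the fraction of each batch drawn from replay, but the mixing code lives in the trainer's inner loop, which is not shown on this page. The sketch below is one plausible reading, assuming the batch size stays constant and the replay count is rounded down; `mix_batch` is a hypothetical helper.

```python
import random
from typing import List


def mix_batch(current: List[dict], replay_pool: List[dict],
              replay_batch_ratio: float, rng: random.Random) -> List[dict]:
    """Replace a fraction of the current-task batch with replayed samples.

    Assumption: the batch size stays constant and the ratio is rounded down.
    """
    if not replay_pool or replay_batch_ratio <= 0:
        return list(current)
    n_replay = min(int(len(current) * replay_batch_ratio), len(current))
    n_current = len(current) - n_replay
    replayed = rng.choices(replay_pool, k=n_replay)
    return list(current[:n_current]) + replayed


current = [{"task": 1, "i": i} for i in range(10)]
pool = [{"task": 0, "i": i} for i in range(500)]
mixed = mix_batch(current, pool, replay_batch_ratio=0.3, rng=random.Random(42))
print(sum(s["task"] == 0 for s in mixed))  # 3
```

With an empty buffer (the first task in a sequence) the batch passes through unchanged, which matches sample returning an empty list.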

ContinualVLATrainer

ContinualVLATrainer(cfg, model, optimizer, lr_scheduler, accelerator)

Bases: TrainerUtils

Sequential task trainer with experience replay support.

Outer loop: iterate over tasks in the CL sequence. Inner loop: standard VLA training on the current task + replay samples.

Source code in AlphaBrain/training/continual_learning/train.py
def __init__(self, cfg, model, optimizer, lr_scheduler, accelerator):
    self.config = cfg
    self.model = model
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
    self.accelerator = accelerator

    self.cl_cfg = cfg.continual_learning
    self.completed_steps = 0
    self.total_batch_size = (
        cfg.datasets.vla_data.per_device_batch_size
        * accelerator.num_processes
        * getattr(cfg.trainer, "gradient_accumulation_steps", 1)
    )

    # LoRA
    from AlphaBrain.training.trainer_utils.peft import is_lora_enabled
    self.use_lora = is_lora_enabled(cfg)

    # Replay buffer
    replay_cfg = self.cl_cfg.replay
    self.replay_enabled = replay_cfg.get("enabled", False)
    if self.replay_enabled:
        self.replay_buffer = ReplayBuffer(
            buffer_size_per_task=replay_cfg.get("buffer_size_per_task", 500),
            seed=cfg.get("seed", 42),
        )
        self.replay_batch_ratio = replay_cfg.get("replay_batch_ratio", 0.3)
        self.balanced_sampling = replay_cfg.get("balanced_sampling", False)
    else:
        self.replay_buffer = None
prepare_training
prepare_training()

Initialize training state (checkpoints, freezing, distributed setup).

Source code in AlphaBrain/training/continual_learning/train.py
def prepare_training(self):
    """Initialize training state (checkpoints, freezing, distributed setup)."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    seed = self.config.seed + rank if hasattr(self.config, "seed") else rank + 3047
    set_seed(seed)

    self._init_checkpointing()

    freeze_modules = (
        self.config.trainer.freeze_modules
        if hasattr(self.config.trainer, "freeze_modules")
        else None
    )
    self.model = self.freeze_backbones(self.model, freeze_modules=freeze_modules)
    self.print_trainable_parameters(self.model)

    # NOTE: we prepare model and optimizer here, dataloaders are prepared per-task.
    # DeepSpeed requires train_micro_batch_size_per_gpu when no dataloader is passed
    # to accelerator.prepare(), so set it explicitly from config.
    if hasattr(self.accelerator.state, "deepspeed_plugin") and self.accelerator.state.deepspeed_plugin is not None:
        ds_cfg = self.accelerator.state.deepspeed_plugin.deepspeed_config
        micro_bs = self.config.datasets.vla_data.per_device_batch_size
        if ds_cfg.get("train_micro_batch_size_per_gpu") == "auto":
            ds_cfg["train_micro_batch_size_per_gpu"] = micro_bs
        if ds_cfg.get("gradient_accumulation_steps") == "auto":
            ds_cfg["gradient_accumulation_steps"] = getattr(self.config.trainer, "gradient_accumulation_steps", 1)
        if ds_cfg.get("train_batch_size") == "auto":
            grad_acc = ds_cfg.get("gradient_accumulation_steps", 1)
            ds_cfg["train_batch_size"] = micro_bs * self.accelerator.num_processes * grad_acc
    self.model, self.optimizer = self.setup_distributed_training(
        self.accelerator, self.model, self.optimizer
    )

    self._init_wandb()
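The DeepSpeed branch above fills in "auto" placeholders because no dataloader is passed to accelerator.prepare() at this point. The same resolution can be sketched as a pure function on a plain dict; `resolve_auto_fields` is a hypothetical helper, while the three config keys are the ones used in the source.

```python
def resolve_auto_fields(ds_cfg: dict, micro_bs: int, num_processes: int, grad_acc: int) -> dict:
    """Fill DeepSpeed "auto" batch fields the same way prepare_training does."""
    cfg = dict(ds_cfg)  # shallow copy so the caller's dict is untouched
    if cfg.get("train_micro_batch_size_per_gpu") == "auto":
        cfg["train_micro_batch_size_per_gpu"] = micro_bs
    if cfg.get("gradient_accumulation_steps") == "auto":
        cfg["gradient_accumulation_steps"] = grad_acc
    if cfg.get("train_batch_size") == "auto":
        # Uses the (possibly just-resolved) accumulation value from the config.
        acc = cfg.get("gradient_accumulation_steps", 1)
        cfg["train_batch_size"] = micro_bs * num_processes * acc
    return cfg


resolved = resolve_auto_fields(
    {"train_micro_batch_size_per_gpu": "auto",
     "gradient_accumulation_steps": "auto",
     "train_batch_size": "auto"},
    micro_bs=8, num_processes=4, grad_acc=2,
)
print(resolved["train_batch_size"])  # 64
```

Fields that already hold explicit numbers are left alone, so a hand-written DeepSpeed config takes precedence over the trainer config.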
train
train(full_dataset, episode_task_map)

Execute the continual learning training loop.

Source code in AlphaBrain/training/continual_learning/train.py
def train(self, full_dataset, episode_task_map):
    """Execute the continual learning training loop."""
    seq_cfg = get_task_sequence(self.cl_cfg.task_sequence)
    num_tasks = seq_cfg["num_tasks"]
    task_order = seq_cfg.get("task_order", list(range(num_tasks)))
    steps_per_task = self.cl_cfg.steps_per_task
    save_per_task = self.cl_cfg.get("save_checkpoint_per_task", True)

    # Determine which tasks to skip based on completed steps
    start_task_idx = self.completed_steps // steps_per_task
    if start_task_idx > 0:
        logger.info(
            f"Resuming from step {self.completed_steps}: "
            f"skipping {start_task_idx} completed tasks"
        )
        # Rebuild replay buffer from completed tasks
        if self.replay_enabled:
            for skip_idx in range(start_task_idx):
                skip_task_id = task_order[skip_idx]
                _, skip_dataset = build_task_dataloader(
                    full_dataset, skip_task_id, episode_task_map, self.config
                )
                self.replay_buffer.populate_from_dataset(
                    task_id=skip_task_id, dataset=skip_dataset,
                )
            logger.info(
                f"Rebuilt replay buffer for {start_task_idx} completed tasks: "
                f"{self.replay_buffer}"
            )

    self._log_cl_config(num_tasks, steps_per_task)

    # Outer tqdm bar over the CL task sequence (1..num_tasks).
    # disable on non-main ranks to avoid duplicate bars under DeepSpeed.
    task_pbar = tqdm(
        enumerate(task_order),
        total=num_tasks,
        desc="CL tasks",
        disable=not self.accelerator.is_local_main_process,
        initial=start_task_idx,
    )
    for task_idx_in_seq, task_id in task_pbar:
        # Skip already completed tasks
        if task_idx_in_seq < start_task_idx:
            task_pbar.write(
                f"Skipping Task {task_idx_in_seq + 1}/{num_tasks} "
                f"(task_index={task_id}) — already completed"
            )
            continue

        task_pbar.set_postfix(task_id=task_id, step=self.completed_steps)
        task_pbar.write(f"{'='*60}")
        task_pbar.write(
            f"Starting Task {task_idx_in_seq + 1}/{num_tasks} "
            f"(task_index={task_id})"
        )
        task_pbar.write(f"{'='*60}")

        # Build per-task dataloader
        task_dataloader, task_dataset = build_task_dataloader(
            full_dataset, task_id, episode_task_map, self.config
        )

        # Prepare dataloader for distributed training
        self.accelerator.dataloader_config.dispatch_batches = False
        task_dataloader = self.accelerator.prepare(task_dataloader)
        dist.barrier()

        # Reset LR scheduler for new task
        # Unwrap through AcceleratedOptimizer → DeepSpeedZeroOptimizer → base AdamW
        base_optimizer = self.optimizer
        while hasattr(base_optimizer, "optimizer"):
            base_optimizer = base_optimizer.optimizer
        task_lr_scheduler = get_scheduler(
            name=self.config.trainer.lr_scheduler_type,
            optimizer=base_optimizer,
            num_warmup_steps=self.config.trainer.num_warmup_steps,
            num_training_steps=steps_per_task,
            scheduler_specific_kwargs=self.config.trainer.scheduler_specific_kwargs,
        )

        # Train on current task
        self._train_single_task(
            task_id=task_id,
            task_idx_in_seq=task_idx_in_seq,
            num_tasks=num_tasks,
            task_dataloader=task_dataloader,
            lr_scheduler=task_lr_scheduler,
            steps_per_task=steps_per_task,
        )

        # Post-task: populate replay buffer
        if self.replay_enabled:
            logger.info(f"Populating replay buffer for task {task_id}...")
            self.replay_buffer.populate_from_dataset(
                task_id=task_id,
                dataset=task_dataset,
            )
            logger.info(f"Replay buffer state: {self.replay_buffer}")

        # Save checkpoint after task
        if save_per_task:
            self._save_task_checkpoint(task_id, task_idx_in_seq)

        dist.barrier()

    self._finalize_training()
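The resume logic at the top of train derives its position purely from completed_steps. A minimal sketch of that arithmetic, with `resume_plan` as a hypothetical helper:

```python
from typing import List, Tuple


def resume_plan(completed_steps: int, steps_per_task: int,
                task_order: List[int]) -> Tuple[List[int], List[int]]:
    """Split a task order into (tasks to rebuild replay for, tasks still to train)."""
    start_task_idx = completed_steps // steps_per_task
    return task_order[:start_task_idx], task_order[start_task_idx:]


done, todo = resume_plan(completed_steps=25_000, steps_per_task=10_000,
                         task_order=[3, 0, 4, 1, 2])
print(done, todo)  # [3, 0] [4, 1, 2]
```

In this example task 4 was interrupted midway, so it is the first task trained again; how many steps it still receives is decided inside _train_single_task, which is not shown here.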

build_full_dataset

build_full_dataset(cfg)

Build the full (unfiltered) VLA dataset from config.

Source code in AlphaBrain/training/continual_learning/train.py
def build_full_dataset(cfg):
    """Build the full (unfiltered) VLA dataset from config."""
    vla_dataset_cfg = cfg.datasets.vla_data
    return get_vla_dataset(data_cfg=vla_dataset_cfg)

build_task_dataloader

build_task_dataloader(full_dataset, task_index, episode_task_map, cfg)

Build a DataLoader for a specific task by filtering the full dataset.

Source code in AlphaBrain/training/continual_learning/train.py
def build_task_dataloader(full_dataset, task_index, episode_task_map, cfg):
    """Build a DataLoader for a specific task by filtering the full dataset."""
    filtered_dataset = TaskFilteredDataset(
        base_dataset=full_dataset,
        task_indices=[task_index],
        episode_task_map=episode_task_map,
    )
    dataloader = DataLoader(
        filtered_dataset,
        batch_size=cfg.datasets.vla_data.per_device_batch_size,
        collate_fn=collate_fn,
        num_workers=4,
        shuffle=True,
    )
    return dataloader, filtered_dataset
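TaskFilteredDataset itself is not shown on this page. As a rough sketch of what per-task filtering could look like, the class below keeps only samples whose episode maps to a requested task. Everything in it is an assumption for illustration, in particular `TaskFilteredView`, the `sample_episode_ids` argument, and the reading of episode_task_map as an episode-index to task-index mapping.

```python
from typing import Dict, List, Sequence


class TaskFilteredView:
    """Hypothetical stand-in for TaskFilteredDataset (real internals not shown)."""

    def __init__(self, base: Sequence[dict], sample_episode_ids: Sequence[int],
                 episode_task_map: Dict[int, int], task_indices: List[int]):
        wanted = set(task_indices)
        self.base = base
        # Keep positions of samples whose episode belongs to a wanted task.
        self.kept = [i for i, ep in enumerate(sample_episode_ids)
                     if episode_task_map.get(ep) in wanted]

    def __len__(self) -> int:
        return len(self.kept)

    def __getitem__(self, idx: int) -> dict:
        return self.base[self.kept[idx]]


base = [{"frame": i} for i in range(6)]
view = TaskFilteredView(base, sample_episode_ids=[0, 0, 1, 1, 2, 2],
                        episode_task_map={0: 0, 1: 1, 2: 0}, task_indices=[0])
print(len(view), view[2])  # 4 {'frame': 4}
```

Any object with __len__ and __getitem__, like this view, meets the dataset requirement stated for populate_from_dataset.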

Algorithms

Base class

base

Abstract base class for continual-learning algorithms.

All CL algorithms (Experience Replay / EWC / LwF / SI / GEM / ...) implement this interface. The continual trainer (AlphaBrain.training.continual_learning.train) only talks to the algorithm through this protocol, so new methods can be plugged in without touching the training loop.

Current implementations
  • ReplayBuffer (algorithms.replay_buffer): reservoir-sampled experience replay with uniform / balanced strategies.

Planned implementations
  • EWC: Elastic Weight Consolidation (Kirkpatrick et al. 2017)
  • LwF: Learning without Forgetting (Li & Hoiem 2017)
  • SI: Synaptic Intelligence (Zenke et al. 2017)
  • GEM: Gradient Episodic Memory (Lopez-Paz & Ranzato 2017)

Replay buffer

replay_buffer

replay_buffer.py

Experience Replay buffer for continual learning. Stores samples from previously learned tasks and provides mixed batches to mitigate catastrophic forgetting.

Supports:
  • Reservoir sampling for memory-efficient storage
  • Per-task buffer management
  • Configurable replay ratio for batch mixing
  • Conforms to the CLAlgorithm interface

ReplayBuffer
ReplayBuffer(buffer_size_per_task: int = 500, seed: int = 42)

Bases: CLAlgorithm

Experience Replay buffer that stores samples from past tasks.

Keeps a fixed-size buffer per task, holding a uniform random subset of that task's dataset (the docstrings call this reservoir sampling; the implementation shuffles all indices and keeps the first k). During training, samples from the buffer are mixed with current-task data at a configurable ratio.

ER is populated at task end: the buffer is filled once after each task finishes via populate_from_dataset, so the per-step observe hook is a no-op. The trainer calls populate_from_dataset directly in its task-end handler.

Usage

buffer = ReplayBuffer(buffer_size_per_task=500)

# After finishing task 0:
buffer.populate_from_dataset(task_id=0, dataset=task0_dataset)

# During task 1 training:
replay_samples = buffer.sample(batch_size=4)  # list[dict]

Parameters:

Name | Type | Description | Default
buffer_size_per_task | int | Maximum number of samples stored per task. | 500
seed | int | Random seed for reproducibility. | 42
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def __init__(self, buffer_size_per_task: int = 500, seed: int = 42):
    """
    Args:
        buffer_size_per_task: Maximum number of samples stored per task.
        seed: Random seed for reproducibility.
    """
    self.buffer_size_per_task = buffer_size_per_task
    self.seed = seed
    self.rng = random.Random(seed)

    # task_id -> list of stored samples
    self._buffers: Dict[int, List[dict]] = {}
    # Track total samples across all tasks
    self._total_samples = 0
num_tasks property
num_tasks: int

Number of tasks stored in the buffer.

total_samples property
total_samples: int

Total number of samples across all tasks.

populate_from_dataset
populate_from_dataset(task_id: int, dataset: Dataset, num_samples: Optional[int] = None)

Store samples from a dataset into the buffer using reservoir sampling.

Parameters:

Name | Type | Description | Default
task_id | int | Identifier for the task. | required
dataset | Dataset | Dataset to sample from (must support __len__ and __getitem__). | required
num_samples | Optional[int] | Number of samples to store. Defaults to buffer_size_per_task. | None
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def populate_from_dataset(self, task_id: int, dataset: Dataset, num_samples: Optional[int] = None):
    """Store samples from a dataset into the buffer using reservoir sampling.

    Args:
        task_id: Identifier for the task.
        dataset: Dataset to sample from (must support __len__ and __getitem__).
        num_samples: Number of samples to store. Defaults to buffer_size_per_task.
    """
    if num_samples is None:
        num_samples = self.buffer_size_per_task

    n = len(dataset)
    k = min(num_samples, n)

    # Reservoir sampling: select k items from n uniformly at random
    indices = list(range(n))
    self.rng.shuffle(indices)
    selected_indices = sorted(indices[:k])

    samples = []
    for idx in selected_indices:
        try:
            sample = dataset[idx]
            samples.append(sample)
        except Exception as e:
            logger.warning(f"Failed to read sample {idx} for task {task_id}: {e}")
            continue

    # Update buffer
    if task_id in self._buffers:
        self._total_samples -= len(self._buffers[task_id])
    self._buffers[task_id] = samples
    self._total_samples += len(samples)

    logger.info(
        f"Replay buffer: stored {len(samples)} samples for task {task_id} "
        f"(total: {self._total_samples} across {self.num_tasks} tasks)"
    )
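
To make the selection step concrete, here is a minimal sketch that puts the shuffle-and-truncate selection from the listing next to classic single-pass reservoir sampling (Algorithm R). Both return a uniform random subset; only the second works on a stream of unknown length. The function names `shuffle_select` and `reservoir_sample` are illustrative and are not part of AlphaBrain.

```python
import random
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def shuffle_select(n: int, k: int, rng: random.Random) -> List[int]:
    """Mirror of the listing: shuffle range(n), keep the first k, sort them."""
    indices = list(range(n))
    rng.shuffle(indices)
    return sorted(indices[: min(k, n)])


def reservoir_sample(stream: Iterable[T], k: int, rng: random.Random) -> List[T]:
    """Algorithm R: uniform k-subset of a stream whose length is not known."""
    reservoir: List[T] = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)      # fill the reservoir first
        else:
            j = rng.randint(0, i)       # inclusive on both ends
            if j < k:
                reservoir[j] = item     # replace with probability k / (i + 1)
    return reservoir


selected = shuffle_select(n=10, k=3, rng=random.Random(0))
streamed = reservoir_sample(iter(range(100)), k=5, rng=random.Random(1))
```

The shuffle variant costs O(n) memory for the index list, which is usually acceptable when the dataset already exposes `__len__`.
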
sample
sample(batch_size: int) -> List[dict]

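Sample a batch uniformly from the pooled samples of all stored tasks, so a task that holds more samples is drawn proportionally more often. Sampling is without replacement when batch_size fits within the pool and with replacement otherwise.

Parameters:

Name        Type  Description                   Default
batch_size  int   Number of samples to return.  required

Returns:

Type        Description
List[dict]  List of sample dicts. Empty list if buffer is empty.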

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def sample(self, batch_size: int) -> List[dict]:
    """Sample a batch uniformly from all stored tasks.

    Args:
        batch_size: Number of samples to return.

    Returns:
        List of sample dicts. Empty list if buffer is empty.
    """
    if self.is_empty():
        return []

    # Collect all samples across tasks
    all_samples = []
    for task_samples in self._buffers.values():
        all_samples.extend(task_samples)

    # Sample without replacement when possible to maximize diversity
    if batch_size <= len(all_samples):
        return self.rng.sample(all_samples, k=batch_size)
    else:
        return self.rng.choices(all_samples, k=batch_size)
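
The two branches above can be tried in isolation. The sketch below assumes a plain dict of per-task lists in place of the real buffer; `draw` is a hypothetical helper that reproduces only the branching logic of the listing.

```python
import random
from typing import List


def draw(pool: List[dict], batch_size: int, rng: random.Random) -> List[dict]:
    """Pooled draw: without replacement if the pool is large enough, else with."""
    if not pool:
        return []
    if batch_size <= len(pool):
        return rng.sample(pool, k=batch_size)   # all items distinct
    return rng.choices(pool, k=batch_size)      # repeats are unavoidable


# Task 0 holds 4 samples and task 1 holds 1, so the pool is 80% task 0.
buffers = {0: [{"task": 0, "i": i} for i in range(4)], 1: [{"task": 1, "i": 0}]}
pool = [s for task_samples in buffers.values() for s in task_samples]

rng = random.Random(42)
small = draw(pool, 3, rng)   # 3 distinct samples
large = draw(pool, 8, rng)   # 8 samples drawn with replacement
```

Because the draw is over the pooled list, per-task balance depends entirely on how many samples each task stored.
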
sample_balanced
sample_balanced(batch_size: int) -> List[dict]

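Sample a batch with roughly equal representation from each stored task. Each task contributes up to max(1, batch_size // num_tasks) samples (drawn with replacement), any shortfall is filled by single draws from randomly chosen tasks, and the result is truncated to batch_size.

Parameters:

Name        Type  Description                   Default
batch_size  int   Number of samples to return.  required

Returns:

Type        Description
List[dict]  List of sample dicts. Empty list if buffer is empty.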

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def sample_balanced(self, batch_size: int) -> List[dict]:
    """Sample a batch with equal representation from each stored task.

    Args:
        batch_size: Number of samples to return.

    Returns:
        List of sample dicts. Empty list if buffer is empty.
    """
    if self.is_empty():
        return []

    samples_per_task = max(1, batch_size // self.num_tasks)
    result = []

    for task_samples in self._buffers.values():
        k = min(samples_per_task, len(task_samples))
        result.extend(self.rng.choices(task_samples, k=k))

    # If we need more samples to reach batch_size, sample randomly
    while len(result) < batch_size:
        task_id = self.rng.choice(list(self._buffers.keys()))
        result.append(self.rng.choice(self._buffers[task_id]))

    return result[:batch_size]
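
A small worked case helps with the quota arithmetic. The sketch below assumes three tasks with four stored samples each; `balanced_draw` is a hypothetical stand-alone copy of the logic in the listing, not the library function.

```python
import random
from collections import Counter
from typing import Dict, List


def balanced_draw(buffers: Dict[int, List[dict]], batch_size: int,
                  rng: random.Random) -> List[dict]:
    """Per-task quota first, then random top-up, then truncate to batch_size."""
    if not buffers:
        return []
    per_task = max(1, batch_size // len(buffers))
    result: List[dict] = []
    for task_samples in buffers.values():
        k = min(per_task, len(task_samples))
        result.extend(rng.choices(task_samples, k=k))
    while len(result) < batch_size:
        task_id = rng.choice(list(buffers.keys()))
        result.append(rng.choice(buffers[task_id]))
    return result[:batch_size]


buffers = {t: [{"task": t, "i": i} for i in range(4)] for t in range(3)}

# batch_size=10 with 3 tasks: quota is 10 // 3 = 3 each, plus 1 top-up draw.
batch = balanced_draw(buffers, 10, random.Random(0))
counts = Counter(s["task"] for s in batch)

# batch_size=2 with 3 tasks: quota is max(1, 0) = 1 each, then truncation to 2.
tiny = balanced_draw(buffers, 2, random.Random(0))
tiny_tasks = [s["task"] for s in tiny]
```

Note the second case: when batch_size is smaller than the number of tasks, truncation keeps only the tasks that were inserted first.
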
get_task_ids
get_task_ids() -> List[int]

Return list of task IDs stored in the buffer.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def get_task_ids(self) -> List[int]:
    """Return list of task IDs stored in the buffer."""
    return sorted(self._buffers.keys())
get_task_size
get_task_size(task_id: int) -> int

Return number of samples stored for a specific task.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def get_task_size(self, task_id: int) -> int:
    """Return number of samples stored for a specific task."""
    return len(self._buffers.get(task_id, []))
clear
clear()

Clear all stored samples.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def clear(self):
    """Clear all stored samples."""
    self._buffers.clear()
    self._total_samples = 0
state_dict
state_dict() -> dict

Return serializable state for checkpointing.

Note: only metadata is serialized — the actual sample tensors are not saved (they can be large). On resume, callers must re-populate the buffer by iterating each task's dataset again.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def state_dict(self) -> dict:
    """Return serializable state for checkpointing.

    Note: only metadata is serialized — the actual sample tensors are not
    saved (they can be large).  On resume, callers must re-populate the
    buffer by iterating each task's dataset again.
    """
    return {
        "algorithm": self.name,
        "buffer_size_per_task": self.buffer_size_per_task,
        "seed": self.seed,
        "num_tasks": self.num_tasks,
        "total_samples": self._total_samples,
        "task_sizes": {k: len(v) for k, v in self._buffers.items()},
    }
observe
observe(batch: dict, task_id: int) -> None

No-op: ER populates from the full dataset at task-end, not per step.

See populate_from_dataset() for the actual memory update, which the trainer invokes from its task-end handler.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def observe(self, batch: dict, task_id: int) -> None:
    """No-op: ER populates from the full dataset at task-end, not per step.

    See :meth:`populate_from_dataset` for the actual memory update, which
    the trainer invokes from its task-end handler.
    """
    return None
on_task_end
on_task_end(task_id: int) -> None

No-op: the trainer calls populate_from_dataset() directly.

(ER needs the full task dataset object, which a generic no-arg hook cannot provide — so the trainer orchestrates the population explicitly.)

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def on_task_end(self, task_id: int) -> None:
    """No-op: the trainer calls :meth:`populate_from_dataset` directly.

    (ER needs the full task dataset object, which a generic no-arg hook
    cannot provide — so the trainer orchestrates the population explicitly.)
    """
    return None
load_state_dict
load_state_dict(state: Dict[str, Any]) -> None

Restore metadata from a snapshot produced by state_dict().

Only hyperparameters are restored (buffer size, seed); the RNG is re-seeded from the restored seed and the buffer is emptied. Actual samples must be re-populated from the task datasets on resume.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def load_state_dict(self, state: Dict[str, Any]) -> None:
    """Restore metadata from a snapshot produced by :meth:`state_dict`.

    Only hyperparameters are restored (buffer size, seed).  Actual samples
    must be re-populated from the task datasets on resume.
    """
    self.buffer_size_per_task = state.get(
        "buffer_size_per_task", self.buffer_size_per_task
    )
    self.seed = state.get("seed", self.seed)
    self.rng = random.Random(self.seed)
    self._buffers = {}
    self._total_samples = 0
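
Since the checkpoint carries metadata only, a resume has to rebuild the buffer contents. The sketch below illustrates that round trip under stated assumptions: `TinyBuffer` is a hypothetical stand-in that copies only the relevant behaviour of the listings, the datasets are plain lists, and JSON is assumed as the checkpoint format.

```python
import json
import random
from typing import Any, Dict, List


class TinyBuffer:
    """Hypothetical stand-in mirroring the metadata-only checkpoint contract."""

    def __init__(self, buffer_size_per_task: int = 2, seed: int = 42):
        self.buffer_size_per_task = buffer_size_per_task
        self.seed = seed
        self.rng = random.Random(seed)
        self._buffers: Dict[int, List[Any]] = {}

    def populate(self, task_id: int, dataset: List[Any]) -> None:
        indices = list(range(len(dataset)))
        self.rng.shuffle(indices)
        keep = sorted(indices[: min(self.buffer_size_per_task, len(dataset))])
        self._buffers[task_id] = [dataset[i] for i in keep]

    def state_dict(self) -> dict:
        return {
            "buffer_size_per_task": self.buffer_size_per_task,
            "seed": self.seed,
            "task_sizes": {k: len(v) for k, v in self._buffers.items()},
        }

    def load_state_dict(self, state: dict) -> None:
        self.buffer_size_per_task = state.get(
            "buffer_size_per_task", self.buffer_size_per_task
        )
        self.seed = state.get("seed", self.seed)
        self.rng = random.Random(self.seed)   # re-seed
        self._buffers = {}                    # samples are not restored


datasets = {0: list("abcdef"), 1: list("uvwxyz")}

original = TinyBuffer(buffer_size_per_task=2, seed=7)
for task_id, ds in datasets.items():
    original.populate(task_id, ds)

# Checkpoint: metadata only. JSON turns the int keys of task_sizes into strings.
restored_state = json.loads(json.dumps(original.state_dict()))

resumed = TinyBuffer()
resumed.load_state_dict(restored_state)
for task_id in sorted(int(k) for k in restored_state["task_sizes"]):
    resumed.populate(task_id, datasets[task_id])
```

In this sketch the resumed buffer matches the original because the RNG is re-seeded and the tasks are re-populated in the same order with the same datasets; a different order would give a different selection.
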

Subpackage exports

algorithms

Continual-learning algorithms.

See AlphaBrain.training.continual_learning.algorithms.base.CLAlgorithm for the interface every algorithm implements.

CLAlgorithm

Bases: ABC

Interface every continual-learning algorithm must satisfy.
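
As a rough illustration of what satisfying the interface involves, the sketch below defines a local stand-in base class with the same five abstract methods and a trivial subclass. `CLAlgorithmSketch` and `StepCounter` are hypothetical names, and the `name` implementation is an assumption based on the documented default (class name).

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class CLAlgorithmSketch(ABC):
    """Local stand-in for the documented interface (not the AlphaBrain class)."""

    @property
    def name(self) -> str:
        return type(self).__name__   # assumed default: the class name

    @abstractmethod
    def observe(self, batch: dict, task_id: int) -> None: ...

    @abstractmethod
    def sample(self, batch_size: int) -> Any: ...

    @abstractmethod
    def on_task_end(self, task_id: int) -> None: ...

    @abstractmethod
    def state_dict(self) -> Dict[str, Any]: ...

    @abstractmethod
    def load_state_dict(self, state: Dict[str, Any]) -> None: ...


class StepCounter(CLAlgorithmSketch):
    """Toy algorithm: counts observed steps per task, contributes nothing."""

    def __init__(self) -> None:
        self.steps: Dict[int, int] = {}

    def observe(self, batch: dict, task_id: int) -> None:
        self.steps[task_id] = self.steps.get(task_id, 0) + 1

    def sample(self, batch_size: int) -> Any:
        return None   # "None/empty to skip"

    def on_task_end(self, task_id: int) -> None:
        return None

    def state_dict(self) -> Dict[str, Any]:
        return {"algorithm": self.name, "steps": dict(self.steps)}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.steps = dict(state.get("steps", {}))


algo = StepCounter()
for task_id in (0, 0, 1):
    algo.observe({}, task_id)
snapshot = algo.state_dict()
```

Leaving any of the five methods unimplemented makes the subclass abstract, so instantiating it raises TypeError.
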

name property
name: str

Short identifier used in logs / checkpoints (default = class name).

observe abstractmethod
observe(batch: dict, task_id: int) -> None

Called every training step with the current task batch.

Replay-based methods (ER, GEM) use this to grow/refresh memory. Regularization methods (EWC, SI) use it to accumulate importance statistics (Fisher information, path integrals, etc.).

Source code in AlphaBrain/training/continual_learning/algorithms/base.py
@abstractmethod
def observe(self, batch: dict, task_id: int) -> None:
    """Called every training step with the **current task** batch.

    Replay-based methods (ER, GEM) use this to grow/refresh memory.
    Regularization methods (EWC, SI) use it to accumulate importance
    statistics (Fisher information, path integrals, etc.).
    """
sample abstractmethod
sample(batch_size: int) -> Any

Return the algorithm's auxiliary artifact for this step (or None/empty to skip).

Return shape is algorithm-specific:
  • ER / GEM: list[dict] (raw samples ready to be collated).
  • EWC / SI: dict[str, Tensor] (per-parameter regularization terms).
  • LwF: dict[str, Tensor] (teacher logits on the current batch).

The trainer dispatches on algorithm type to combine this with the current-task batch (e.g. mix-in ratio for replay, KL term for LwF).
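
How the trainer actually combines the two sources is not shown on this page. As one possible reading of "mix-in ratio for replay", the sketch below replaces a fixed fraction of the current batch with replay samples; `mix_replay` and `replay_ratio` are hypothetical names, not AlphaBrain API.

```python
from typing import List


def mix_replay(current: List[dict], replay: List[dict],
               replay_ratio: float) -> List[dict]:
    """Replace a fraction of the current batch with replay samples."""
    if not replay:                       # empty buffer, e.g. during the first task
        return list(current)
    n_replay = min(len(replay), int(round(len(current) * replay_ratio)))
    n_current = len(current) - n_replay
    return current[:n_current] + replay[:n_replay]


current_batch = [{"task": 1, "i": i} for i in range(8)]
replay_samples = [{"task": 0, "i": i} for i in range(4)]   # e.g. buffer.sample(4)

# 8 * 0.25 = 2 replay samples, 6 current-task samples.
mixed = mix_replay(current_batch, replay_samples, replay_ratio=0.25)
```

Keeping the total batch size constant is a design choice of this sketch; a trainer could equally append replay samples and let the batch grow.
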

Source code in AlphaBrain/training/continual_learning/algorithms/base.py
@abstractmethod
def sample(self, batch_size: int) -> Any:
    """Return the algorithm's auxiliary artifact for this step (or None/empty to skip).

    Return shape is algorithm-specific:
      * ER / GEM : `list[dict]`       — raw samples ready to be collated.
      * EWC / SI : `dict[str, Tensor]`— per-parameter regularization terms.
      * LwF      : `dict[str, Tensor]`— teacher logits on the current batch.

    The trainer dispatches on algorithm type to combine this with the
    current-task batch (e.g. mix-in ratio for replay, KL term for LwF).
    """
on_task_end abstractmethod
on_task_end(task_id: int) -> None

Hook invoked after a task finishes.

Typical uses:
  • EWC: snapshot parameters, compute Fisher on current task's dataset.
  • LwF: snapshot the teacher model weights.
  • ER : no-op (reservoir sampling happens online).

Source code in AlphaBrain/training/continual_learning/algorithms/base.py
@abstractmethod
def on_task_end(self, task_id: int) -> None:
    """Hook invoked after a task finishes.

    Typical uses:
    * EWC: snapshot parameters, compute Fisher on current task's dataset.
    * LwF: snapshot the teacher model weights.
    * ER : no-op (reservoir sampling happens online).
    """
state_dict abstractmethod
state_dict() -> dict[str, Any]

Return a JSON-serializable snapshot of the algorithm state.

This is written alongside model checkpoints so CL state survives interruption/restart across tasks.

Source code in AlphaBrain/training/continual_learning/algorithms/base.py
@abstractmethod
def state_dict(self) -> dict[str, Any]:
    """Return a JSON-serializable snapshot of the algorithm state.

    This is written alongside model checkpoints so CL state survives
    interruption/restart across tasks.
    """
load_state_dict abstractmethod
load_state_dict(state: dict[str, Any]) -> None

Restore algorithm state from a dict produced by state_dict().

Source code in AlphaBrain/training/continual_learning/algorithms/base.py
@abstractmethod
def load_state_dict(self, state: dict[str, Any]) -> None:
    """Restore algorithm state from a dict produced by `state_dict()`."""
ReplayBuffer
ReplayBuffer(buffer_size_per_task: int = 500, seed: int = 42)

Bases: CLAlgorithm

Experience Replay buffer that stores samples from past tasks.

Keeps a uniform random subset of at most buffer_size_per_task samples per task (the source describes this as reservoir sampling). During training, samples from the buffer are mixed with current task data at a configurable ratio.

ER is a task-end-populated algorithm (the buffer is filled once after each task finishes via populate_from_dataset()), so the per-step observe() hook is a no-op. The trainer calls populate_from_dataset() directly in its task-end handler.

Usage

buffer = ReplayBuffer(buffer_size_per_task=500)

# After finishing task 0:
buffer.populate_from_dataset(task_id=0, dataset=task0_dataset)

# During task 1 training:
replay_samples = buffer.sample(batch_size=4)  # list[dict]

Parameters:

Name                  Type  Description                                 Default
buffer_size_per_task  int   Maximum number of samples stored per task.  500
seed                  int   Random seed for reproducibility.            42
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def __init__(self, buffer_size_per_task: int = 500, seed: int = 42):
    """
    Args:
        buffer_size_per_task: Maximum number of samples stored per task.
        seed: Random seed for reproducibility.
    """
    self.buffer_size_per_task = buffer_size_per_task
    self.seed = seed
    self.rng = random.Random(seed)

    # task_id -> list of stored samples
    self._buffers: Dict[int, List[dict]] = {}
    # Track total samples across all tasks
    self._total_samples = 0
num_tasks property
num_tasks: int

Number of tasks stored in the buffer.

total_samples property
total_samples: int

Total number of samples across all tasks.

populate_from_dataset
populate_from_dataset(task_id: int, dataset: Dataset, num_samples: Optional[int] = None)

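Store a uniform random subset of a dataset in the buffer. The source describes this as reservoir sampling; the listing shuffles all indices and keeps the first k, so the dataset length must be known up front. Calling it again for an existing task_id replaces that task's samples.

Parameters:

Name         Type           Description                                                      Default
task_id      int            Identifier for the task.                                         required
dataset      Dataset        Dataset to sample from (must support __len__ and __getitem__).   required
num_samples  Optional[int]  Number of samples to store. Defaults to buffer_size_per_task.    None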
Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def populate_from_dataset(self, task_id: int, dataset: Dataset, num_samples: Optional[int] = None):
    """Store samples from a dataset into the buffer using reservoir sampling.

    Args:
        task_id: Identifier for the task.
        dataset: Dataset to sample from (must support __len__ and __getitem__).
        num_samples: Number of samples to store. Defaults to buffer_size_per_task.
    """
    if num_samples is None:
        num_samples = self.buffer_size_per_task

    n = len(dataset)
    k = min(num_samples, n)

    # Reservoir sampling: select k items from n uniformly at random
    indices = list(range(n))
    self.rng.shuffle(indices)
    selected_indices = sorted(indices[:k])

    samples = []
    for idx in selected_indices:
        try:
            sample = dataset[idx]
            samples.append(sample)
        except Exception as e:
            logger.warning(f"Failed to read sample {idx} for task {task_id}: {e}")
            continue

    # Update buffer
    if task_id in self._buffers:
        self._total_samples -= len(self._buffers[task_id])
    self._buffers[task_id] = samples
    self._total_samples += len(samples)

    logger.info(
        f"Replay buffer: stored {len(samples)} samples for task {task_id} "
        f"(total: {self._total_samples} across {self.num_tasks} tasks)"
    )
sample
sample(batch_size: int) -> List[dict]

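Sample a batch uniformly from the pooled samples of all stored tasks, so a task that holds more samples is drawn proportionally more often. Sampling is without replacement when batch_size fits within the pool and with replacement otherwise.

Parameters:

Name        Type  Description                   Default
batch_size  int   Number of samples to return.  required

Returns:

Type        Description
List[dict]  List of sample dicts. Empty list if buffer is empty.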

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def sample(self, batch_size: int) -> List[dict]:
    """Sample a batch uniformly from all stored tasks.

    Args:
        batch_size: Number of samples to return.

    Returns:
        List of sample dicts. Empty list if buffer is empty.
    """
    if self.is_empty():
        return []

    # Collect all samples across tasks
    all_samples = []
    for task_samples in self._buffers.values():
        all_samples.extend(task_samples)

    # Sample without replacement when possible to maximize diversity
    if batch_size <= len(all_samples):
        return self.rng.sample(all_samples, k=batch_size)
    else:
        return self.rng.choices(all_samples, k=batch_size)
sample_balanced
sample_balanced(batch_size: int) -> List[dict]

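Sample a batch with roughly equal representation from each stored task. Each task contributes up to max(1, batch_size // num_tasks) samples (drawn with replacement), any shortfall is filled by single draws from randomly chosen tasks, and the result is truncated to batch_size.

Parameters:

Name        Type  Description                   Default
batch_size  int   Number of samples to return.  required

Returns:

Type        Description
List[dict]  List of sample dicts. Empty list if buffer is empty.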

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def sample_balanced(self, batch_size: int) -> List[dict]:
    """Sample a batch with equal representation from each stored task.

    Args:
        batch_size: Number of samples to return.

    Returns:
        List of sample dicts. Empty list if buffer is empty.
    """
    if self.is_empty():
        return []

    samples_per_task = max(1, batch_size // self.num_tasks)
    result = []

    for task_samples in self._buffers.values():
        k = min(samples_per_task, len(task_samples))
        result.extend(self.rng.choices(task_samples, k=k))

    # If we need more samples to reach batch_size, sample randomly
    while len(result) < batch_size:
        task_id = self.rng.choice(list(self._buffers.keys()))
        result.append(self.rng.choice(self._buffers[task_id]))

    return result[:batch_size]
get_task_ids
get_task_ids() -> List[int]

Return list of task IDs stored in the buffer.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def get_task_ids(self) -> List[int]:
    """Return list of task IDs stored in the buffer."""
    return sorted(self._buffers.keys())
get_task_size
get_task_size(task_id: int) -> int

Return number of samples stored for a specific task.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def get_task_size(self, task_id: int) -> int:
    """Return number of samples stored for a specific task."""
    return len(self._buffers.get(task_id, []))
clear
clear()

Clear all stored samples.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def clear(self):
    """Clear all stored samples."""
    self._buffers.clear()
    self._total_samples = 0
state_dict
state_dict() -> dict

Return serializable state for checkpointing.

Note: only metadata is serialized — the actual sample tensors are not saved (they can be large). On resume, callers must re-populate the buffer by iterating each task's dataset again.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def state_dict(self) -> dict:
    """Return serializable state for checkpointing.

    Note: only metadata is serialized — the actual sample tensors are not
    saved (they can be large).  On resume, callers must re-populate the
    buffer by iterating each task's dataset again.
    """
    return {
        "algorithm": self.name,
        "buffer_size_per_task": self.buffer_size_per_task,
        "seed": self.seed,
        "num_tasks": self.num_tasks,
        "total_samples": self._total_samples,
        "task_sizes": {k: len(v) for k, v in self._buffers.items()},
    }
observe
observe(batch: dict, task_id: int) -> None

No-op: ER populates from the full dataset at task-end, not per step.

See populate_from_dataset() for the actual memory update, which the trainer invokes from its task-end handler.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def observe(self, batch: dict, task_id: int) -> None:
    """No-op: ER populates from the full dataset at task-end, not per step.

    See :meth:`populate_from_dataset` for the actual memory update, which
    the trainer invokes from its task-end handler.
    """
    return None
on_task_end
on_task_end(task_id: int) -> None

No-op: the trainer calls populate_from_dataset() directly.

(ER needs the full task dataset object, which a generic no-arg hook cannot provide — so the trainer orchestrates the population explicitly.)

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def on_task_end(self, task_id: int) -> None:
    """No-op: the trainer calls :meth:`populate_from_dataset` directly.

    (ER needs the full task dataset object, which a generic no-arg hook
    cannot provide — so the trainer orchestrates the population explicitly.)
    """
    return None
load_state_dict
load_state_dict(state: Dict[str, Any]) -> None

Restore metadata from a snapshot produced by state_dict().

Only hyperparameters are restored (buffer size, seed); the RNG is re-seeded from the restored seed and the buffer is emptied. Actual samples must be re-populated from the task datasets on resume.

Source code in AlphaBrain/training/continual_learning/algorithms/replay_buffer.py
def load_state_dict(self, state: Dict[str, Any]) -> None:
    """Restore metadata from a snapshot produced by :meth:`state_dict`.

    Only hyperparameters are restored (buffer size, seed).  Actual samples
    must be re-populated from the task datasets on resume.
    """
    self.buffer_size_per_task = state.get(
        "buffer_size_per_task", self.buffer_size_per_task
    )
    self.seed = state.get("seed", self.seed)
    self.rng = random.Random(self.seed)
    self._buffers = {}
    self._total_samples = 0

Datasets / task sequences

task_sequences

continual_learning.py

Defines continual learning task sequences for sequential task training. Each sequence specifies a base data_mix and task ordering. Provides utilities to filter datasets by task_index for per-task training.

TaskFilteredDataset
TaskFilteredDataset(base_dataset, task_indices: List[int], episode_task_map: Dict[int, List[int]])

Bases: Dataset

Wraps a LeRobotMixtureDataset (or a LeRobotSingleDataset) so that only steps from the given task indices are exposed.

This is a lightweight wrapper that filters the base dataset's step sampling without copying data or modifying the underlying dataset.

Parameters:

Name              Type                  Description                                         Default
base_dataset      -                     A LeRobotMixtureDataset (or LeRobotSingleDataset).  required
task_indices      List[int]             List of task_index values to include.               required
episode_task_map  Dict[int, List[int]]  Mapping from task_index -> list of episode_ids.     required
Source code in AlphaBrain/training/continual_learning/datasets/task_sequences.py
def __init__(self, base_dataset, task_indices: List[int], episode_task_map: Dict[int, List[int]]):
    """
    Args:
        base_dataset: A LeRobotMixtureDataset (or LeRobotSingleDataset).
        task_indices: List of task_index values to include.
        episode_task_map: Mapping from task_index -> list of episode_ids.
    """
    self.base_dataset = base_dataset
    self.task_indices = task_indices

    # Build set of valid episode ids for fast lookup
    self.valid_episodes = set()
    for ti in task_indices:
        if ti in episode_task_map:
            self.valid_episodes.update(episode_task_map[ti])

    # For MixtureDataset: filter each sub-dataset's all_steps
    # For SingleDataset: filter directly
    if hasattr(base_dataset, 'datasets'):
        # MixtureDataset
        self._filtered_steps_per_dataset = []
        self._total_steps = 0
        for ds in base_dataset.datasets:
            filtered = [
                (traj_id, base_idx)
                for traj_id, base_idx in ds.all_steps
                if traj_id in self.valid_episodes
            ]
            self._filtered_steps_per_dataset.append(filtered)
            self._total_steps += len(filtered)
    else:
        # SingleDataset
        self._filtered_steps = [
            (traj_id, base_idx)
            for traj_id, base_idx in base_dataset.all_steps
            if traj_id in self.valid_episodes
        ]
        self._total_steps = len(self._filtered_steps)

    logger.info(
        f"TaskFilteredDataset: tasks={task_indices}, "
        f"episodes={len(self.valid_episodes)}, steps={self._total_steps}"
    )
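
The core of the constructor is a set-membership filter over (trajectory id, step index) pairs. The sketch below isolates that step with plain lists; `filter_steps` is a hypothetical helper, and the episode ids are made up for illustration.

```python
from typing import Dict, List, Set, Tuple

Step = Tuple[int, int]   # (trajectory_id, index within the trajectory)


def filter_steps(all_steps: List[Step], task_indices: List[int],
                 episode_task_map: Dict[int, List[int]]) -> List[Step]:
    """Keep only the steps whose episode belongs to one of the requested tasks."""
    valid_episodes: Set[int] = set()
    for ti in task_indices:
        # Task indices missing from the map contribute no episodes.
        valid_episodes.update(episode_task_map.get(ti, []))
    return [(traj_id, base_idx) for traj_id, base_idx in all_steps
            if traj_id in valid_episodes]


episode_task_map = {0: [10, 11], 1: [12]}
all_steps = [(10, 0), (10, 1), (11, 0), (12, 0), (12, 1)]

task0_steps = filter_steps(all_steps, [0], episode_task_map)
unknown_task_steps = filter_steps(all_steps, [7], episode_task_map)
```

A task index that is absent from the map silently yields an empty dataset, so it can be worth checking the step count logged by the constructor.
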
datasets property
datasets

Expose underlying datasets for compatibility.

save_dataset_statistics
save_dataset_statistics(path)

Delegate to base dataset.

Source code in AlphaBrain/training/continual_learning/datasets/task_sequences.py
def save_dataset_statistics(self, path):
    """Delegate to base dataset."""
    if hasattr(self.base_dataset, 'save_dataset_statistics'):
        self.base_dataset.save_dataset_statistics(path)
get_task_sequence
get_task_sequence(sequence_name: str) -> dict

Retrieve a CL task sequence by name. Raises ValueError, listing the available sequence names, if the name is unknown.

Source code in AlphaBrain/training/continual_learning/datasets/task_sequences.py
def get_task_sequence(sequence_name: str) -> dict:
    """Retrieve a CL task sequence by name."""
    if sequence_name not in CL_TASK_SEQUENCES:
        raise ValueError(
            f"Unknown CL task sequence: {sequence_name}. "
            f"Available: {list(CL_TASK_SEQUENCES.keys())}"
        )
    return CL_TASK_SEQUENCES[sequence_name]
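
The lookup pattern can be exercised without the real registry. In the sketch below the registry contents are invented: the entry name `demo_two_tasks` and the keys `base_data_mix` and `task_order` are assumptions for illustration and do not reflect the actual layout of CL_TASK_SEQUENCES.

```python
from typing import Dict

# Hypothetical registry: the entry name and its keys are illustrative only.
DEMO_TASK_SEQUENCES: Dict[str, dict] = {
    "demo_two_tasks": {"base_data_mix": "demo_mix", "task_order": [0, 1]},
}


def get_demo_sequence(name: str) -> dict:
    """Same lookup-or-raise pattern as the listing, on the demo registry."""
    if name not in DEMO_TASK_SEQUENCES:
        raise ValueError(
            f"Unknown CL task sequence: {name}. "
            f"Available: {list(DEMO_TASK_SEQUENCES.keys())}"
        )
    return DEMO_TASK_SEQUENCES[name]


sequence = get_demo_sequence("demo_two_tasks")

try:
    get_demo_sequence("typo")
    error_message = ""
except ValueError as exc:
    error_message = str(exc)   # names the available sequences
```

As in the listing, the returned dict is the registry entry itself rather than a copy, so callers that mutate it also change the registry.
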
build_episode_task_map
build_episode_task_map(dataset) -> Dict[int, List[int]]

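Build a mapping from task_index to the list of episode IDs by reading each episode's data. The task index is read from the first row of the task_index column; if that column is missing, the first column whose name contains "task" is used, and if no such column exists the episode is assigned to task 0. Episodes that fail to load or parse are skipped with a warning.

Parameters:

Name     Type  Description                       Default
dataset  -     A LeRobotSingleDataset instance.  required

Returns:

Type                  Description
Dict[int, List[int]]  Dict mapping task_index -> list of trajectory_ids (episode indices).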

Source code in AlphaBrain/training/continual_learning/datasets/task_sequences.py
def build_episode_task_map(dataset) -> Dict[int, List[int]]:
    """Build mapping from task_index to list of episode_ids by reading episode data.

    Args:
        dataset: A LeRobotSingleDataset instance.

    Returns:
        Dict mapping task_index -> list of trajectory_ids (episode indices).
    """
    task_to_episodes: Dict[int, List[int]] = defaultdict(list)
    seen_episodes = set()

    for traj_id in dataset.trajectory_ids:
        if traj_id in seen_episodes:
            continue
        seen_episodes.add(traj_id)

        try:
            data = dataset.get_trajectory_data(traj_id)
            if "task_index" in data.columns:
                task_idx = int(data["task_index"].iloc[0])
            else:
                # Fallback: try annotation-based task index
                annotation_cols = [c for c in data.columns if "task" in c.lower()]
                if annotation_cols:
                    task_idx = int(data[annotation_cols[0]].iloc[0])
                else:
                    logger.warning(
                        f"No task_index column found for episode {traj_id}, assigning task 0"
                    )
                    task_idx = 0
            task_to_episodes[task_idx].append(traj_id)
        except Exception as e:
            logger.warning(f"Failed to read task_index for episode {traj_id}: {e}")
            continue

    # Clear dataset cache
    dataset.curr_traj_data = None
    dataset.curr_traj_id = None

    logger.info(
        f"Built episode-task map: {len(task_to_episodes)} tasks, "
        f"{sum(len(v) for v in task_to_episodes.values())} total episodes"
    )
    for task_idx in sorted(task_to_episodes.keys()):
        logger.info(f"  Task {task_idx}: {len(task_to_episodes[task_idx])} episodes")

    return dict(task_to_episodes)
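
The column fallback is easier to follow on toy data. The sketch below assumes each episode is represented by the first row of its data as a plain dict (the real function reads a DataFrame through dataset.get_trajectory_data); `group_episodes_by_task` is a hypothetical helper and omits the try/except and logging of the listing.

```python
from collections import defaultdict
from typing import Dict, List


def group_episodes_by_task(first_rows: Dict[int, dict]) -> Dict[int, List[int]]:
    """Group episode ids by task index, with the same column fallback order."""
    task_to_episodes: Dict[int, List[int]] = defaultdict(list)
    for episode_id, row in first_rows.items():
        if "task_index" in row:
            task_idx = int(row["task_index"])
        else:
            # Fallback: first column whose name mentions "task", else task 0.
            task_cols = [c for c in row if "task" in c.lower()]
            task_idx = int(row[task_cols[0]]) if task_cols else 0
        task_to_episodes[task_idx].append(episode_id)
    return dict(task_to_episodes)


first_rows = {
    0: {"task_index": 2},
    1: {"task_index": 2},
    2: {"annotation.task": 5},   # no task_index column: use a "task" column
    3: {"action": 0.1},          # no task-like column at all: task 0
}
mapping = group_episodes_by_task(first_rows)
```

The resulting dict is what TaskFilteredDataset expects as its episode_task_map argument.
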

datasets

Continual-learning data primitives: task sequences + per-task filtering.

TaskFilteredDataset
TaskFilteredDataset(base_dataset, task_indices: List[int], episode_task_map: Dict[int, List[int]])

Bases: Dataset

Wraps a LeRobotMixtureDataset (or a LeRobotSingleDataset) so that only steps from the given task indices are exposed.

This is a lightweight wrapper that filters the base dataset's step sampling without copying data or modifying the underlying dataset.

Parameters:

Name              Type                  Description                                         Default
base_dataset      -                     A LeRobotMixtureDataset (or LeRobotSingleDataset).  required
task_indices      List[int]             List of task_index values to include.               required
episode_task_map  Dict[int, List[int]]  Mapping from task_index -> list of episode_ids.     required
Source code in AlphaBrain/training/continual_learning/datasets/task_sequences.py
def __init__(self, base_dataset, task_indices: List[int], episode_task_map: Dict[int, List[int]]):
    """
    Args:
        base_dataset: A LeRobotMixtureDataset (or LeRobotSingleDataset).
        task_indices: List of task_index values to include.
        episode_task_map: Mapping from task_index -> list of episode_ids.
    """
    self.base_dataset = base_dataset
    self.task_indices = task_indices

    # Build set of valid episode ids for fast lookup
    self.valid_episodes = set()
    for ti in task_indices:
        if ti in episode_task_map:
            self.valid_episodes.update(episode_task_map[ti])

    # For MixtureDataset: filter each sub-dataset's all_steps
    # For SingleDataset: filter directly
    if hasattr(base_dataset, 'datasets'):
        # MixtureDataset
        self._filtered_steps_per_dataset = []
        self._total_steps = 0
        for ds in base_dataset.datasets:
            filtered = [
                (traj_id, base_idx)
                for traj_id, base_idx in ds.all_steps
                if traj_id in self.valid_episodes
            ]
            self._filtered_steps_per_dataset.append(filtered)
            self._total_steps += len(filtered)
    else:
        # SingleDataset
        self._filtered_steps = [
            (traj_id, base_idx)
            for traj_id, base_idx in base_dataset.all_steps
            if traj_id in self.valid_episodes
        ]
        self._total_steps = len(self._filtered_steps)

    logger.info(
        f"TaskFilteredDataset: tasks={task_indices}, "
        f"episodes={len(self.valid_episodes)}, steps={self._total_steps}"
    )
datasets property
datasets

Expose underlying datasets for compatibility.

save_dataset_statistics
save_dataset_statistics(path)

Delegate to base dataset.

Source code in AlphaBrain/training/continual_learning/datasets/task_sequences.py
def save_dataset_statistics(self, path):
    """Delegate to base dataset."""
    if hasattr(self.base_dataset, 'save_dataset_statistics'):
        self.base_dataset.save_dataset_statistics(path)
build_episode_task_map
build_episode_task_map(dataset) -> Dict[int, List[int]]

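Build a mapping from task_index to the list of episode IDs by reading each episode's data. The task index is read from the first row of the task_index column; if that column is missing, the first column whose name contains "task" is used, and if no such column exists the episode is assigned to task 0. Episodes that fail to load or parse are skipped with a warning.

Parameters:

Name     Type  Description                       Default
dataset  -     A LeRobotSingleDataset instance.  required

Returns:

Type                  Description
Dict[int, List[int]]  Dict mapping task_index -> list of trajectory_ids (episode indices).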

Source code in AlphaBrain/training/continual_learning/datasets/task_sequences.py
def build_episode_task_map(dataset) -> Dict[int, List[int]]:
    """Build mapping from task_index to list of episode_ids by reading episode data.

    Args:
        dataset: A LeRobotSingleDataset instance.

    Returns:
        Dict mapping task_index -> list of trajectory_ids (episode indices).
    """
    task_to_episodes: Dict[int, List[int]] = defaultdict(list)
    seen_episodes = set()

    for traj_id in dataset.trajectory_ids:
        if traj_id in seen_episodes:
            continue
        seen_episodes.add(traj_id)

        try:
            data = dataset.get_trajectory_data(traj_id)
            if "task_index" in data.columns:
                task_idx = int(data["task_index"].iloc[0])
            else:
                # Fallback: try annotation-based task index
                annotation_cols = [c for c in data.columns if "task" in c.lower()]
                if annotation_cols:
                    task_idx = int(data[annotation_cols[0]].iloc[0])
                else:
                    logger.warning(
                        f"No task_index column found for episode {traj_id}, assigning task 0"
                    )
                    task_idx = 0
            task_to_episodes[task_idx].append(traj_id)
        except Exception as e:
            logger.warning(f"Failed to read task_index for episode {traj_id}: {e}")
            continue

    # Clear dataset cache
    dataset.curr_traj_data = None
    dataset.curr_traj_id = None

    logger.info(
        f"Built episode-task map: {len(task_to_episodes)} tasks, "
        f"{sum(len(v) for v in task_to_episodes.values())} total episodes"
    )
    for task_idx in sorted(task_to_episodes.keys()):
        logger.info(f"  Task {task_idx}: {len(task_to_episodes[task_idx])} episodes")

    return dict(task_to_episodes)
get_task_sequence
get_task_sequence(sequence_name: str) -> dict

Retrieve a CL task sequence by name. Raises ValueError, listing the available sequence names, if the name is unknown.

Source code in AlphaBrain/training/continual_learning/datasets/task_sequences.py
def get_task_sequence(sequence_name: str) -> dict:
    """Retrieve a CL task sequence by name."""
    if sequence_name not in CL_TASK_SEQUENCES:
        raise ValueError(
            f"Unknown CL task sequence: {sequence_name}. "
            f"Available: {list(CL_TASK_SEQUENCES.keys())}"
        )
    return CL_TASK_SEQUENCES[sequence_name]