Skip to content

Training › General Training

Source path: AlphaBrain/training/

Generic training entrypoints and shared utilities for VLA models. Continual Learning and Reinforcement Learning have their own pages:


Training entrypoints

train_alphabrain.py — main training entrypoint

train_alphabrain

AlphaBrain’s trainer is built directly on native PyTorch + Accelerate + DeepSpeed, keeping the loop explicit and easy to hack. Conventions: 1. Store runtime state in dicts where possible (simplifies data info, processing info, config, etc.). 2. Use multiple dataloaders to adapt heterogeneous data types / task mixtures. 3. Put each training strategy in its own trainer_*.py file (avoid large if‑else chains).

VLATrainer
VLATrainer(cfg, model, vla_train_dataloader, optimizer, lr_scheduler, accelerator)

Bases: TrainerUtils

Source code in AlphaBrain/training/train_alphabrain.py
def __init__(self, cfg, model, vla_train_dataloader, optimizer, lr_scheduler, accelerator):
    """Store pre-built training components and derive trainer state.

    :param cfg: full training config; ``cfg.trainer.ema`` is read here for
        optional EMA settings.
    :param model: the VLA model to train.
    :param vla_train_dataloader: dataloader yielding VLA training batches.
    :param optimizer: already-constructed optimizer.
    :param lr_scheduler: already-constructed LR scheduler.
    :param accelerator: Accelerate object managing distributed execution.
    """
    self.config = cfg
    self.model = model
    self.vla_train_dataloader = vla_train_dataloader
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
    self.accelerator = accelerator

    # LoRA: function-scope import keeps the PEFT dependency optional at module load
    from AlphaBrain.training.trainer_utils.peft import is_lora_enabled
    self.use_lora = is_lora_enabled(cfg)

    # training status tracking
    self.completed_steps = 0
    self.total_batch_size = self._calculate_total_batch_size()

    # EMA (Exponential Moving Average) — active only when cfg.trainer.ema.enabled is truthy
    ema_cfg = getattr(cfg.trainer, 'ema', None)
    self.use_ema = ema_cfg is not None and getattr(ema_cfg, 'enabled', False)
    self.ema_decay = getattr(ema_cfg, 'decay', 0.99) if ema_cfg else 0.99  # default decay 0.99
    self.ema_model = None  # initialized after distributed setup
train
train()

execute training loop

Source code in AlphaBrain/training/train_alphabrain.py
def train(self):
    """Execute training loop.

    Runs until ``cfg.trainer.max_train_steps`` optimizer steps complete,
    interleaving batch fetch, train step, periodic evaluation, metric
    logging and checkpointing.
    """
    # print training config
    self._log_training_config()

    # prepare data iterators
    self._create_data_iterators()

    # create progress bar (resumes at completed_steps; rendered on local rank 0 only)
    progress_bar = tqdm(
        range(self.config.trainer.max_train_steps),
        initial=self.completed_steps,
        disable=not self.accelerator.is_local_main_process
    )

    # main training loop
    while self.completed_steps < self.config.trainer.max_train_steps:
        # get data batch (timed separately from the model step)
        t_start_data = time.perf_counter()
        batch_vla = self._get_next_batch()
        t_end_data = time.perf_counter()

        # execute training step
        t_start_model = time.perf_counter()
        step_metrics = self._train_step(batch_vla)
        t_end_model = time.perf_counter()

        # update progress only on full optimizer steps (gradient accumulation boundary)
        if self.accelerator.sync_gradients:
            progress_bar.update(1)
            self.completed_steps += 1

        if self.accelerator.is_local_main_process:
            _postfix = {
                "action_dit_loss": f"{step_metrics.get('action_dit_loss', 0):.4f}",
                "data": f"{t_end_data - t_start_data:.3f}s",
                "fwd": f"{t_end_model - t_start_model:.3f}s",
            }
            if "video_loss" in step_metrics:
                _postfix["video_loss"] = f"{step_metrics['video_loss']:.4f}"
                _postfix["total_loss"] = f"{step_metrics.get('total_loss', 0):.4f}"
            progress_bar.set_postfix(_postfix)

        # evaluate model
        # NOTE(review): this condition also fires while completed_steps == 0
        # (i.e. before the first full optimizer step) — confirm intended.
        if self.completed_steps % self.config.trainer.eval_interval == 0:
            try:
                step_metrics = self.eval_action_model(step_metrics)
            except Exception as e:
                # best-effort eval: never abort training because of it
                if self.accelerator.is_main_process:
                    logger.warning(f"eval_action_model failed: {e}, skipping")


        # record metrics
        step_metrics["data_time"] = t_end_data - t_start_data
        step_metrics["model_time"] = t_end_model - t_start_model
        self._log_metrics(step_metrics)

        # save checkpoint
        if self.completed_steps % self.config.trainer.save_interval == 0 and self.completed_steps > 0:
            self._save_checkpoint()

        # check termination condition
        if self.completed_steps >= self.config.trainer.max_train_steps:
            break

    # training end processing
    self._finalize_training()
eval_action_model
eval_action_model(step_metrics: dict = None) -> dict

Evaluate the model on the given dataset using the specified metric function.

:param step_metrics: Metrics dict for the current training step; the evaluation score is added under the key 'mse_score' on the main process. :return: The updated metrics dict.

Source code in AlphaBrain/training/train_alphabrain.py
def eval_action_model(self, step_metrics: dict = None) -> dict:
    """Run a one-batch action-prediction evaluation.

    Pulls the next training batch, predicts actions with DDIM sampling and
    records the per-element Euclidean error under ``mse_score``.

    :param step_metrics: metrics dict for the current step; ``mse_score``
        is added on the main process. ``None`` is treated as an empty dict.
    :return: the (possibly updated) metrics dict.
    """
    # Fix: original declared `-> float` but always returned a dict, and its
    # docstring described parameters that do not exist. Also guard against
    # `step_metrics=None` (would have raised on item assignment below).
    if step_metrics is None:
        step_metrics = {}

    examples = self._get_next_batch()
    if examples is None:
        logger.warning('eval_action_model: got None batch, skipping')
        return step_metrics
    batch_images = [example["image"] for example in examples]
    instructions = [example["lang"] for example in examples]
    actions = [example["action"] for example in examples]  # ground-truth labels
    # states are optional; presence is decided by the first example's keys
    states = [example["state"] for example in examples] if "state" in examples[0] else None
    # Predict actions using the model (DDIM for fast deterministic sampling)
    output_dict = self.model.predict_action(
        batch_images=batch_images, instructions=instructions, states=states,
        use_ddim=True, num_ddim_steps=20
    )

    if self.accelerator.is_main_process:
        normalized_actions = output_dict["normalized_actions"]  # B, T, D
        actions = np.array(actions)  # convert actions to numpy.ndarray
        num_elements = np.prod(actions.shape)
        # Compute the metric score, averaged per action element
        score = TrainerUtils.euclidean_distance(normalized_actions, actions)
        step_metrics["mse_score"] = score / num_elements

    del examples
    if dist.is_initialized():
        dist.barrier()  # ensure all processes are synchronized
    return step_metrics
setup_file_logging
setup_file_logging(output_dir: str, rank: int = 0)

Add a FileHandler to root logger so all log messages are saved to a local file. Only the main process (rank 0) writes to avoid multi-process file conflicts.

Source code in AlphaBrain/training/train_alphabrain.py
def setup_file_logging(output_dir: str, rank: int = 0):
    """Attach a timestamped FileHandler to the root logger.

    Only the main process (rank 0) installs the handler, so concurrent
    ranks never contend for the same file.

    :param output_dir: run directory; the log lands in ``<output_dir>/logs``.
    :param rank: process rank; non-zero ranks are a no-op.
    :return: path of the created log file, or ``None`` for rank != 0.
    """
    if rank != 0:
        return None

    log_dir = os.path.join(output_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")

    handler = logging.FileHandler(log_file, encoding="utf-8")
    handler.setLevel(logging.INFO)
    handler.setFormatter(
        logging.Formatter(
            "[%(asctime)s][%(name)s][%(levelname)s] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    logging.getLogger().addHandler(handler)
    return log_file
setup_directories
setup_directories(cfg) -> Path

create output directory and save config

Source code in AlphaBrain/training/train_alphabrain.py
def setup_directories(cfg) -> Path:
    """create output directory and save config

    Resolves ``cfg.output_dir`` as ``<output_root_dir>/<run_id>``, then (on
    rank 0 only, or always when not distributed) creates the run and
    checkpoint directories and installs file logging. All callers receive
    the resolved ``Path``.
    """
    cfg.output_dir = os.path.join(cfg.output_root_dir, cfg.run_id)
    output_dir = Path(cfg.output_dir)

    # only rank 0 touches the filesystem (or any process when not distributed)
    if not dist.is_initialized() or dist.get_rank() == 0:
        # create output directory and checkpoint directory
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(output_dir / "checkpoints", exist_ok=True)

        # setup file logging to save logs locally
        log_file = setup_file_logging(str(output_dir), rank=0)
        if log_file:
            logger.info(f"Training logs will be saved to: {log_file}")

        # # save config (currently disabled)
        # OmegaConf.save(cfg, output_dir / "config.yaml")
        # with open(output_dir / "config.yaml", "r") as f_yaml, open(output_dir / "config.json", "w") as f_json:
        #     yaml_cfg = yaml.safe_load(f_yaml)
        #     json.dump(yaml_cfg, f_json, indent=2)

    return output_dir
build_model
build_model(cfg) -> torch.nn.Module

build model framework

Source code in AlphaBrain/training/train_alphabrain.py
def build_model(cfg) -> torch.nn.Module:
    """Instantiate the model framework described by ``cfg.framework``."""
    qwenvl_cfg = getattr(cfg.framework, 'qwenvl', None)
    if qwenvl_cfg is not None and hasattr(qwenvl_cfg, 'base_vlm'):
        logger.info(f"Loading Base VLM `{cfg.framework.qwenvl.base_vlm}` from ID/Path")
    else:
        logger.info(f"Building framework: {cfg.framework.name}")
    return build_framework(cfg)
prepare_data
prepare_data(cfg, accelerator, output_dir) -> Tuple[DataLoader, DataLoader]

prepare training data

Source code in AlphaBrain/training/train_alphabrain.py
def prepare_data(cfg, accelerator, output_dir) -> DataLoader:
    """Prepare training data (VLA loader only).

    Fix: the original annotation promised ``Tuple[DataLoader, DataLoader]``
    but the function builds and returns a single VLA dataloader.

    :param cfg: config with ``cfg.datasets.vla_data`` settings.
    :param accelerator: Accelerate object; batch dispatching is disabled so
        each rank pulls its own batches.
    :param output_dir: unused here; kept for a uniform entrypoint signature.
    :return: the VLA training dataloader.
    """
    # VLA data loader
    dataset_mix = getattr(cfg.datasets.vla_data, 'dataset_mix', 'N/A')
    logger.info(f"Creating VLA Dataset with Mixture `{dataset_mix}`")
    vla_train_dataloader = build_dataloader(cfg=cfg, dataloader_module=cfg.datasets.vla_data.dataloader_module)

    # each process iterates its own loader; never dispatch from rank 0
    accelerator.dataloader_config.dispatch_batches = False
    if dist.is_initialized():
        dist.barrier()

    return vla_train_dataloader
setup_optimizer_and_scheduler
setup_optimizer_and_scheduler(model, cfg) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]

set optimizer and scheduler

Source code in AlphaBrain/training/train_alphabrain.py
def setup_optimizer_and_scheduler(model, cfg) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
    """set optimizer and scheduler

    Builds an AdamW optimizer over per-module LR groups, then either a
    custom lambda-linear scheduler (when ``cfg.trainer.scheduler_type ==
    'lambda_linear'``) or a named HuggingFace scheduler.

    :param model: model whose parameters are grouped via build_param_lr_groups.
    :param cfg: training config; reads ``cfg.trainer.*`` settings.
    :return: (optimizer, lr_scheduler) tuple.
    """
    # initialize optimizer
    param_groups = build_param_lr_groups(model=model, cfg=cfg)
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=cfg.trainer.learning_rate.base,
        betas=tuple(cfg.trainer.optimizer.betas),
        weight_decay=cfg.trainer.optimizer.weight_decay,
        eps=cfg.trainer.optimizer.eps,
    )

    # print optimizer group info (rank 0 only)
    if dist.is_initialized() and dist.get_rank() == 0:
        for i, group in enumerate(optimizer.param_groups):
            logger.info(f"LR Group {group['name']}: lr={group['lr']}, num_params={len(group['params'])}")

    # initialize learning rate scheduler
    scheduler_type = getattr(cfg.trainer, 'scheduler_type', None)
    if scheduler_type == 'lambda_linear':
        from omegaconf import OmegaConf
        t = cfg.trainer
        # convert OmegaConf list nodes to plain lists for the scheduler builder
        cycle_lengths = list(OmegaConf.to_container(t.cycle_lengths, resolve=True))
        warm_up_steps = list(OmegaConf.to_container(t.warm_up_steps, resolve=True))
        f_start = list(OmegaConf.to_container(t.f_start, resolve=True))
        f_max = list(OmegaConf.to_container(t.f_max, resolve=True))
        f_min = list(OmegaConf.to_container(t.f_min, resolve=True))
        lr_scheduler = _build_lambda_linear_scheduler(
            optimizer, cycle_lengths, warm_up_steps, f_start, f_max, f_min
        )
    else:
        lr_scheduler = get_scheduler(
            name=cfg.trainer.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=cfg.trainer.num_warmup_steps,
            num_training_steps=cfg.trainer.max_train_steps,
            scheduler_specific_kwargs=cfg.trainer.scheduler_specific_kwargs,  # minimum learning rate
        )

    return optimizer, lr_scheduler

train_alphabrain_cotrain.py — co-training

train_alphabrain_cotrain

AlphaBrain’s trainer is built directly on native PyTorch + Accelerate + DeepSpeed, keeping the loop explicit and easy to hack. Conventions: 1. Store runtime state in dicts where possible (simplifies data info, processing info, config, etc.). 2. Use multiple dataloaders to adapt heterogeneous data types / task mixtures. 3. Put each training strategy in its own trainer_*.py file (avoid large if‑else chains).

VLATrainer
VLATrainer(cfg, model, vla_train_dataloader, vlm_train_dataloader, optimizer, lr_scheduler, accelerator)

Bases: TrainerUtils

Source code in AlphaBrain/training/train_alphabrain_cotrain.py
def __init__(self, cfg, model, vla_train_dataloader, vlm_train_dataloader, optimizer, lr_scheduler, accelerator):
    """Store co-training components (separate VLA and VLM dataloaders).

    :param cfg: training config.
    :param model: model trained jointly on VLA and VLM data.
    :param vla_train_dataloader: loader for action (VLA) batches.
    :param vlm_train_dataloader: loader for vision-language (VLM) batches.
    :param optimizer: already-constructed optimizer.
    :param lr_scheduler: already-constructed LR scheduler.
    :param accelerator: Accelerate object managing distributed execution.
    """
    self.config = cfg
    self.model = model
    self.vla_train_dataloader = vla_train_dataloader
    self.vlm_train_dataloader = vlm_train_dataloader
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
    self.accelerator = accelerator

    # training progress tracking
    self.completed_steps = 0
    self.total_batch_size = self._calculate_total_batch_size()
train
train()

Execute training loop.

Source code in AlphaBrain/training/train_alphabrain_cotrain.py
def train(self):
    """Execute the co-training loop (one VLA batch + one VLM batch per step).

    Runs until ``cfg.trainer.max_train_steps`` optimizer steps complete,
    interleaving batch fetch, train step, periodic evaluation, metric
    logging and checkpointing.
    """
    self._log_training_config()
    self._create_data_iterators()
    progress_bar = tqdm(
        range(self.config.trainer.max_train_steps), disable=not self.accelerator.is_local_main_process
    )

    while self.completed_steps < self.config.trainer.max_train_steps:
        # fetch one batch from each dataloader (timed separately)
        t_start_data = time.perf_counter()
        batch_vla, batch_vlm = self._get_next_batch()
        t_end_data = time.perf_counter()

        t_start_model = time.perf_counter()
        step_metrics = self._train_step(batch_vla, batch_vlm)
        t_end_model = time.perf_counter()

        # count a step only at the gradient-accumulation boundary
        if self.accelerator.sync_gradients:
            progress_bar.update(1)
            self.completed_steps += 1

        if self.accelerator.is_local_main_process:
            progress_bar.set_postfix(
                {
                    "data_times": f"{t_end_data - t_start_data:.3f}",
                    "model_times": f"{t_end_model - t_start_model:.3f}",
                }
            )

        if self.completed_steps % self.config.trainer.eval_interval == 0:
            step_metrics = self.eval_action_model(step_metrics)

        step_metrics["data_time"] = t_end_data - t_start_data
        step_metrics["model_time"] = t_end_model - t_start_model
        self._log_metrics(step_metrics)

        if self.completed_steps % self.config.trainer.save_interval == 0 and self.completed_steps > 0:
            self._save_checkpoint()
            # Fix: barrier is only valid with an initialized process group;
            # an unconditional call crashed single-process (non-dist) runs.
            if dist.is_initialized():
                dist.barrier()

        if self.completed_steps >= self.config.trainer.max_train_steps:
            break

    self._finalize_training()
eval_action_model
eval_action_model(step_metrics: dict = None) -> dict

Evaluate action prediction with current model.

Source code in AlphaBrain/training/train_alphabrain_cotrain.py
def eval_action_model(self, step_metrics: dict = None) -> dict:
    """Evaluate action prediction with current model.

    Fixes vs. original: returns a dict (annotation said ``float``); guards
    ``step_metrics=None`` (item assignment would have raised); guards the
    barrier so non-distributed runs do not crash.

    :param step_metrics: metrics dict; ``mse_score`` is added on rank 0.
    :return: the updated metrics dict.
    """
    if step_metrics is None:
        step_metrics = {}

    if self.accelerator.is_main_process:
        examples, _ = self._get_next_batch()  # VLM half of the batch is unused here
        actions = [example["action"] for example in examples]

        output_dict = self.model.predict_action(examples=examples)
        normalized_actions = output_dict["normalized_actions"]

        actions = np.array(actions)
        num_elements = np.prod(actions.shape)
        # per-element average distance between prediction and label
        score = TrainerUtils.euclidean_distance(normalized_actions, actions)
        step_metrics["mse_score"] = score / num_elements

    if dist.is_initialized():
        dist.barrier()  # keep all ranks in lockstep after eval
    return step_metrics
setup_file_logging
setup_file_logging(output_dir: str, rank: int = 0)

Add a FileHandler to root logger so all log messages are saved to a local file. Only the main process (rank 0) writes to avoid multi-process file conflicts.

Source code in AlphaBrain/training/train_alphabrain_cotrain.py
def setup_file_logging(output_dir: str, rank: int = 0):
    """Mirror all root-logger output into a local, timestamped log file.

    Non-main ranks return immediately so only one process ever writes.

    :param output_dir: run directory; logs are stored under ``logs/``.
    :param rank: process rank; must be 0 to install the handler.
    :return: the log-file path, or ``None`` when rank != 0.
    """
    if rank != 0:
        return None

    log_dir = os.path.join(output_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")

    fmt = logging.Formatter(
        "[%(asctime)s][%(name)s][%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(fmt)
    logging.getLogger().addHandler(file_handler)
    return log_file
setup_directories
setup_directories(cfg) -> Path

Create output directory and checkpoint directory.

Source code in AlphaBrain/training/train_alphabrain_cotrain.py
def setup_directories(cfg) -> Path:
    """Create the run's output/checkpoint directories and enable file logging.

    Filesystem work happens on rank 0 only (or unconditionally when torch
    distributed is not initialized); every caller gets the resolved ``Path``.
    """
    cfg.output_dir = os.path.join(cfg.output_root_dir, cfg.run_id)
    output_dir = Path(cfg.output_dir)

    rank_zero = (not dist.is_initialized()) or dist.get_rank() == 0
    if rank_zero:
        for directory in (output_dir, output_dir / "checkpoints"):
            os.makedirs(directory, exist_ok=True)

        # mirror console logs into <output_dir>/logs on the main process
        log_file = setup_file_logging(str(output_dir), rank=0)
        if log_file:
            logger.info(f"Training logs will be saved to: {log_file}")

    return output_dir
prepare_data
prepare_data(cfg, accelerator, output_dir) -> Tuple[DataLoader, DataLoader]

Prepare co-training data.

Source code in AlphaBrain/training/train_alphabrain_cotrain.py
def prepare_data(cfg, accelerator, output_dir) -> Tuple[DataLoader, DataLoader]:
    """Prepare co-training data (VLA and VLM dataloaders).

    Fix: the barrier is now guarded by ``dist.is_initialized()`` — the
    unconditional call crashed non-distributed runs (the single-loader
    entrypoint already guards it).

    :param cfg: config with ``cfg.datasets.vla_data`` / ``cfg.datasets.vlm_data``.
    :param accelerator: Accelerate object; batch dispatching is disabled so
        each rank pulls its own batches.
    :param output_dir: unused here; kept for a uniform entrypoint signature.
    :return: (vla_train_dataloader, vlm_train_dataloader).
    """
    logger.info(f"Creating VLA Dataset with Mixture `{cfg.datasets.vla_data.dataset_mix}`")
    vla_train_dataloader = build_dataloader(cfg=cfg, dataloader_module=cfg.datasets.vla_data.dataloader_module)
    vlm_train_dataloader = build_dataloader(cfg=cfg, dataloader_module=cfg.datasets.vlm_data.dataloader_module)

    accelerator.dataloader_config.dispatch_batches = False
    if dist.is_initialized():
        dist.barrier()
    return vla_train_dataloader, vlm_train_dataloader
setup_optimizer_and_scheduler
setup_optimizer_and_scheduler(model, cfg) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]

Set optimizer and learning rate scheduler.

Source code in AlphaBrain/training/train_alphabrain_cotrain.py
def setup_optimizer_and_scheduler(model, cfg) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
    """Build the AdamW optimizer (per-group LRs) and its LR scheduler."""
    opt_cfg = cfg.trainer.optimizer
    optimizer = torch.optim.AdamW(
        build_param_lr_groups(model=model, cfg=cfg),
        lr=cfg.trainer.learning_rate.base,
        betas=tuple(opt_cfg.betas),
        weight_decay=opt_cfg.weight_decay,
        eps=opt_cfg.eps,
    )

    # describe the LR group layout once, from the main rank
    if dist.is_initialized() and dist.get_rank() == 0:
        for group in optimizer.param_groups:
            logger.info(f"LR Group {group['name']}: lr={group['lr']}, num_params={len(group['params'])}")

    lr_scheduler = get_scheduler(
        name=cfg.trainer.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=cfg.trainer.num_warmup_steps,
        num_training_steps=cfg.trainer.max_train_steps,
        scheduler_specific_kwargs=cfg.trainer.scheduler_specific_kwargs,
    )

    return optimizer, lr_scheduler

train_alphabrain_vlm.py — VLM-only training

train_alphabrain_vlm

AlphaBrain’s trainer is built directly on native PyTorch + Accelerate + DeepSpeed, keeping the loop explicit and easy to hack. Conventions: 1. Store runtime state in dicts where possible (simplifies data info, processing info, config, etc.). 2. Use multiple dataloaders to adapt heterogeneous data types / task mixtures. 3. Put each training strategy in its own trainer_*.py file (avoid large if‑else chains).

VLATrainer
VLATrainer(cfg, model, vlm_train_dataloader, optimizer, lr_scheduler, accelerator)

Bases: TrainerUtils

Source code in AlphaBrain/training/train_alphabrain_vlm.py
def __init__(self, cfg, model, vlm_train_dataloader, optimizer, lr_scheduler, accelerator):
    """Store VLM-only training components.

    :param cfg: training config.
    :param model: model trained on VLM data only.
    :param vlm_train_dataloader: loader for vision-language (VLM) batches.
    :param optimizer: already-constructed optimizer.
    :param lr_scheduler: already-constructed LR scheduler.
    :param accelerator: Accelerate object managing distributed execution.
    """
    self.config = cfg
    self.model = model
    self.vlm_train_dataloader = vlm_train_dataloader
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
    self.accelerator = accelerator

    # training progress tracking
    self.completed_steps = 0
    self.total_batch_size = self._calculate_total_batch_size()
train
train()

Execute training loop.

Source code in AlphaBrain/training/train_alphabrain_vlm.py
def train(self):
    """Execute the VLM-only training loop.

    Runs until ``cfg.trainer.max_train_steps`` optimizer steps complete,
    with periodic (no-op) evaluation, metric logging and checkpointing.
    """
    self._log_training_config()
    self._create_data_iterators()
    progress_bar = tqdm(
        range(self.config.trainer.max_train_steps), disable=not self.accelerator.is_local_main_process
    )

    while self.completed_steps < self.config.trainer.max_train_steps:
        batch_vlm = self._get_next_batch()
        step_metrics = self._train_step(batch_vlm)

        # count a step only at the gradient-accumulation boundary
        if self.accelerator.sync_gradients:
            progress_bar.update(1)
            self.completed_steps += 1

        if self.completed_steps % self.config.trainer.eval_interval == 0:
            step_metrics = self.eval_action_model(step_metrics)

        self._log_metrics(step_metrics)

        if self.completed_steps % self.config.trainer.save_interval == 0 and self.completed_steps > 0:
            self._save_checkpoint()
            # Fix: barrier is only valid with an initialized process group;
            # an unconditional call crashed single-process (non-dist) runs.
            if dist.is_initialized():
                dist.barrier()

        if self.completed_steps >= self.config.trainer.max_train_steps:
            break

    self._finalize_training()
eval_action_model
eval_action_model(step_metrics=None)

No-op evaluation for VLM-only training.

Source code in AlphaBrain/training/train_alphabrain_vlm.py
def eval_action_model(self, step_metrics=None):
    """No-op evaluation hook for VLM-only training.

    Returns the given metrics dict unchanged, or an empty dict when the
    input is falsy (``None`` or empty).
    """
    if step_metrics:
        return step_metrics
    return {}
setup_directories
setup_directories(cfg) -> Path

Create output directory and checkpoint directory.

Source code in AlphaBrain/training/train_alphabrain_vlm.py
def setup_directories(cfg) -> Path:
    """Create the run's output directory tree (rank 0 only) and return it.

    Resolves ``cfg.output_dir`` as ``<output_root_dir>/<run_id>`` for all
    callers; directory creation happens on rank 0 or whenever torch
    distributed is not initialized.
    """
    cfg.output_dir = os.path.join(cfg.output_root_dir, cfg.run_id)
    output_dir = Path(cfg.output_dir)

    is_rank_zero = (not dist.is_initialized()) or dist.get_rank() == 0
    if is_rank_zero:
        for sub in (output_dir, output_dir / "checkpoints"):
            os.makedirs(sub, exist_ok=True)

    return output_dir
prepare_data
prepare_data(cfg, accelerator, output_dir) -> DataLoader

Prepare VLM training data.

Source code in AlphaBrain/training/train_alphabrain_vlm.py
def prepare_data(cfg, accelerator, output_dir) -> DataLoader:
    """Prepare VLM training data.

    Fix: the barrier is now guarded by ``dist.is_initialized()`` — the
    unconditional call crashed non-distributed runs.

    :param cfg: config with ``cfg.datasets.vlm_data`` settings.
    :param accelerator: Accelerate object; batch dispatching is disabled so
        each rank pulls its own batches.
    :param output_dir: unused here; kept for a uniform entrypoint signature.
    :return: the VLM training dataloader.
    """
    logger.info(f"Creating VLM Dataset `{cfg.datasets.vlm_data.dataset_use}`")
    vlm_train_dataloader = build_dataloader(cfg=cfg, dataloader_module=cfg.datasets.vlm_data.dataloader_module)

    accelerator.dataloader_config.dispatch_batches = False
    if dist.is_initialized():
        dist.barrier()
    return vlm_train_dataloader
setup_optimizer_and_scheduler
setup_optimizer_and_scheduler(model, cfg) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]

Set optimizer and learning rate scheduler.

Source code in AlphaBrain/training/train_alphabrain_vlm.py
def setup_optimizer_and_scheduler(model, cfg) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
    """Construct the AdamW optimizer and LR scheduler from ``cfg.trainer``."""
    trainer_cfg = cfg.trainer
    optimizer = torch.optim.AdamW(
        build_param_lr_groups(model=model, cfg=cfg),
        lr=trainer_cfg.learning_rate.base,
        betas=tuple(trainer_cfg.optimizer.betas),
        weight_decay=trainer_cfg.optimizer.weight_decay,
        eps=trainer_cfg.optimizer.eps,
    )

    # log each LR group once, from the main rank only
    if dist.is_initialized() and dist.get_rank() == 0:
        for group in optimizer.param_groups:
            logger.info(f"LR Group {group['name']}: lr={group['lr']}, num_params={len(group['params'])}")

    lr_scheduler = get_scheduler(
        name=trainer_cfg.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=trainer_cfg.num_warmup_steps,
        num_training_steps=trainer_cfg.max_train_steps,
        scheduler_specific_kwargs=trainer_cfg.scheduler_specific_kwargs,
    )

    return optimizer, lr_scheduler

train_stdp.py — STDP spiking-model training

train_stdp

STDP Fine-tuning Training Script for NeuroVLA.

This script loads a pretrained NeuroVLA checkpoint and fine-tunes the SNN action head using Reward-Modulated STDP (R-STDP), optionally blended with standard backpropagation gradients.

Modes
  • hybrid: Δw = α·Δw_backprop + β·Δw_rstdp (default)
  • pure_stdp: Δw = Δw_rstdp only (no backprop for SNN weights)
Usage

accelerate launch AlphaBrain/training/train_stdp.py --config_yaml configs/finetune_config.yaml --mode neuro_vla_stdp

STDPTrainer
STDPTrainer(cfg, model, dataloader, optimizer, lr_scheduler, accelerator)

Bases: TrainerUtils

Trainer for R-STDP fine-tuning of NeuroVLA.

Extends the standard training loop with: 1. SpikeMonitor to record spike timing from LIF layers 2. STDPLearner to compute STDP weight updates 3. RSTDPOptimizer to blend backprop and STDP updates

Source code in AlphaBrain/training/train_stdp.py
def __init__(self, cfg, model, dataloader, optimizer, lr_scheduler, accelerator):
    """Set up R-STDP fine-tuning state on top of the standard trainer pieces.

    :param cfg: training config; optional ``cfg.stdp`` section controls
        R-STDP behaviour (enabled, mode, alpha/beta blend weights).
    :param model: pretrained NeuroVLA model to fine-tune.
    :param dataloader: training dataloader.
    :param optimizer: already-constructed optimizer.
    :param lr_scheduler: already-constructed LR scheduler.
    :param accelerator: Accelerate object managing distributed execution.
    """
    self.config = cfg
    self.model = model
    self.dataloader = dataloader
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
    self.accelerator = accelerator

    self.completed_steps = 0
    self.total_batch_size = self._calculate_total_batch_size()

    # STDP configuration (all fields optional; defaults applied below)
    stdp_cfg = cfg.stdp if hasattr(cfg, "stdp") else OmegaConf.create({})
    self.stdp_enabled = getattr(stdp_cfg, "enabled", True)
    # 'hybrid' blends backprop with R-STDP; 'pure_stdp' uses R-STDP only
    self.stdp_mode = getattr(stdp_cfg, "mode", "hybrid")
    self.alpha = getattr(stdp_cfg, "alpha", 0.7)  # backprop weight in hybrid mode
    self.beta = getattr(stdp_cfg, "beta", 0.3)  # R-STDP weight in hybrid mode

    # STDP components (initialized in prepare_training)
    self.spike_monitor = None
    self.stdp_learner = None
    self.rstdp_optimizer = None

    # EMA reward tracker for smoother R-STDP signal
    self._ema_loss: float = None  # None until the first loss is observed
    self._ema_decay: float = 0.95

Trainer utilities

Shared training utilities: structured logging (overwatch), PEFT, finetune configuration, checkpoint tracking, and more.

Overwatch (unified logging)

overwatch

overwatch.py

Original file from OpenVLA project (Prismatic), licensed under MIT License.
See https://github.com/openvla/openvla for full license text and contributors.
Modified by @JinhuiYE, [2025]

Utility class for creating a centralized/standardized logger (built on Rich) and accelerate handler.

DistributedOverwatch
DistributedOverwatch(name: str)

Initializer for an Overwatch object that wraps logging & accelerate.PartialState.

Source code in AlphaBrain/training/trainer_utils/overwatch.py
def __init__(self, name: str) -> None:
    """Initializer for an Overwatch object that wraps logging & `accelerate.PartialState`.

    :param name: logger name, usually the calling module's ``__name__``.
    """
    # function-scope import keeps `accelerate` optional at module load
    from accelerate import PartialState

    # Note that PartialState is always safe to initialize regardless of `accelerate launch` or `torchrun`
    #   =>> However, might be worth actually figuring out if we need the `accelerate` dependency at all!
    self.logger, self.distributed_state = ContextAdapter(logging.getLogger(name), extra={}), PartialState()

    # Logger Delegation (expose debug/info/warning/error/critical directly on this object)
    self.debug = self.logger.debug
    self.info = self.logger.info
    self.warning = self.logger.warning
    self.error = self.logger.error
    self.critical = self.logger.critical

    # Logging Defaults =>> only Log `INFO` on Main Process, `ERROR` on others!
    self.logger.setLevel(logging.INFO if self.distributed_state.is_main_process else logging.ERROR)
PureOverwatch
PureOverwatch(name: str)

Initializer for an Overwatch object that just wraps logging.

Source code in AlphaBrain/training/trainer_utils/overwatch.py
def __init__(self, name: str) -> None:
    """Initializer for an Overwatch object that just wraps logging."""
    self.logger = ContextAdapter(logging.getLogger(name), extra={})

    # Delegate the standard logging methods straight to the adapter
    for level_name in ("debug", "info", "warning", "error", "critical"):
        setattr(self, level_name, getattr(self.logger, level_name))

    # Default verbosity: INFO
    self.logger.setLevel(logging.INFO)

Finetune configuration

finetune_config

Utilities for loading finetune_config.yaml as the primary training config.

Merge order (lowest → highest priority): configs/models/.yaml < configs/datasets/.yaml < configs/trainer/.yaml < train_recipe (if mode.config_yaml is set) < finetune_config global sections (environment, seed) < mode-derived field mappings < mode.framework / mode.datasets / mode.trainer direct overrides < mode.extra_args < CLI args (applied by caller)

expand_env_vars
expand_env_vars(value)

Expand bash-style ${VAR} / ${VAR:-default} in a string. No-op for non-strings.

Source code in AlphaBrain/training/trainer_utils/finetune_config.py
def expand_env_vars(value):
    """Expand bash-style ``${VAR}`` / ``${VAR:-default}`` in a string.

    Non-string inputs are returned untouched. An unset variable without a
    default expands to the empty string, matching bash behaviour.
    """
    if not isinstance(value, str):
        return value

    pattern = r'\$\{([A-Za-z_][A-Za-z0-9_]*)(:-(.*?))?\}'

    def substitute(match):
        name = match.group(1)
        fallback = match.group(3)
        if fallback is None:
            fallback = ""
        return os.environ.get(name, fallback)

    return re.sub(pattern, substitute, value)
build_config_from_finetune
build_config_from_finetune(finetune_cfg, mode: str)

Build an OmegaConf training config from finetune_config.yaml + mode name.

Source code in AlphaBrain/training/trainer_utils/finetune_config.py
def build_config_from_finetune(finetune_cfg, mode: str):
    """Build an OmegaConf training config from finetune_config.yaml + mode name.

    Merge order (lowest → highest priority):
        model/dataset/trainer base YAMLs < optional train recipe
        < global sections (environment, seed) < mode field mappings
        < direct nested mode overrides < mode.extra_args
    CLI overrides are applied by the caller on top of the returned config.

    Args:
        finetune_cfg: parsed finetune_config.yaml (OmegaConf config).
        mode: name of the entry under ``finetune_cfg.modes`` to build.

    Returns:
        Merged OmegaConf config ready for the trainer.

    Raises:
        ValueError: if ``mode`` is not present in ``finetune_cfg.modes``.
    """
    all_modes = OmegaConf.to_container(finetune_cfg.modes, resolve=False)
    if mode not in all_modes:
        raise ValueError(f"Mode '{mode}' not found. Available: {list(all_modes.keys())}")

    # Work with a plain dict to avoid OmegaConf misinterpreting bash ${...} syntax
    mode_dict = OmegaConf.to_container(finetune_cfg.modes[mode], resolve=False)
    global_defaults = OmegaConf.to_container(finetune_cfg.get('defaults', {}), resolve=False)

    # ── 1. Base configs (model / dataset / trainer defaults) ──────────────────
    base_cfgs = []
    model_key    = mode_dict.get('model')    or global_defaults.get('model')
    dataset_key  = mode_dict.get('dataset')  or global_defaults.get('dataset')
    trainer_key  = mode_dict.get('trainer_defaults') or global_defaults.get('trainer')
    if model_key:   base_cfgs.append(OmegaConf.load(f"configs/models/{model_key}.yaml"))
    if dataset_key: base_cfgs.append(OmegaConf.load(f"configs/datasets/{dataset_key}.yaml"))
    if trainer_key: base_cfgs.append(OmegaConf.load(f"configs/trainer/{trainer_key}.yaml"))

    # Optional train recipe (backward compat; mode.config_yaml)
    recipe_path = mode_dict.get('config_yaml', '')
    if recipe_path and os.path.exists(recipe_path):
        recipe = OmegaConf.load(recipe_path)
        # recipe may point at a model config file to pre-merge underneath it
        if '_model_config_' in recipe:
            recipe = OmegaConf.merge(OmegaConf.load(recipe.pop('_model_config_')), recipe)
        if 'defaults' in recipe:
            rd = recipe.pop('defaults')
            if 'model' in rd:
                recipe = OmegaConf.merge(OmegaConf.load(f"configs/models/{rd.model}.yaml"), recipe)
        base_cfgs.append(recipe)

    base = OmegaConf.merge(*base_cfgs) if base_cfgs else OmegaConf.create({})

    # ── 2. Global overrides from finetune_config (environment, seed) ──────────
    # NOTE: 'paths' is intentionally excluded — it's only for path resolution,
    #       not part of the training config, and its bash ${...} values would
    #       break OmegaConf interpolation resolution later.
    global_ov = {}
    for key in ('environment', 'seed'):
        if key in finetune_cfg:
            val = finetune_cfg[key]
            global_ov[key] = OmegaConf.to_container(val, resolve=False) if OmegaConf.is_config(val) else val

    # ── 3. Mode field mappings ─────────────────────────────────────────────────
    mode_ov = {}

    if 'run_id' in mode_dict:
        mode_ov['run_id'] = mode_dict['run_id']

    if 'output_root_dir' in mode_dict:
        mode_ov['output_root_dir'] = mode_dict['output_root_dir']
    elif 'common' in finetune_cfg and 'output_root_dir' in finetune_cfg.common:
        mode_ov['output_root_dir'] = finetune_cfg.common.output_root_dir

    if 'framework_name' in mode_dict:
        mode_ov.setdefault('framework', {})['name'] = mode_dict['framework_name']

    if 'base_vlm' in mode_dict:
        base_vlm = expand_env_vars(mode_dict['base_vlm'])
        # The pretrained-model directory is read uniformly from the
        # PRETRAINED_MODELS_DIR environment variable
        pretrained_dir = os.environ.get('PRETRAINED_MODELS_DIR', 'data/pretrained_models')
        # Relative names (not abs, not ./, not data/) are resolved under pretrained_dir
        if not os.path.isabs(base_vlm) and not base_vlm.startswith('./') and not base_vlm.startswith('data/'):
            base_vlm = os.path.join(pretrained_dir, base_vlm)
        mode_ov.setdefault('framework', {}).setdefault('qwenvl', {})['base_vlm'] = base_vlm

    if 'data_root' in mode_dict:
        mode_ov.setdefault('datasets', {}).setdefault('vla_data', {})['data_root_dir'] = expand_env_vars(mode_dict['data_root'])
    if 'dataset_mix' in mode_dict:
        mode_ov.setdefault('datasets', {}).setdefault('vla_data', {})['dataset_mix'] = mode_dict['dataset_mix']

    training = mode_dict.get('training', {})
    for field in ('gradient_accumulation_steps', 'max_train_steps', 'save_interval', 'eval_interval', 'freeze_modules', 'pretrained_checkpoint'):
        if field in training:
            mode_ov.setdefault('trainer', {})[field] = training[field]
    if 'per_device_batch_size' in training:
        mode_ov.setdefault('datasets', {}).setdefault('vla_data', {})['per_device_batch_size'] = training['per_device_batch_size']

    # ── 4. Direct nested overrides (framework / datasets / trainer in mode) ───
    direct_ov = {k: mode_dict[k] for k in ('framework', 'datasets', 'trainer', 'trackers', 'wandb_project', 'wandb_entity', 'is_debug', 'stdp', 'lora') if k in mode_dict}

    def _recursive_expand_env(obj):
        """Recursively expand ${VAR} / ${VAR:-default} in all string values."""
        if isinstance(obj, str):
            return expand_env_vars(obj)
        elif isinstance(obj, dict):
            return {k: _recursive_expand_env(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [_recursive_expand_env(v) for v in obj]
        return obj

    direct_ov = _recursive_expand_env(direct_ov)

    # ── 5. Merge everything ───────────────────────────────────────────────────
    cfg = OmegaConf.merge(base, OmegaConf.create(global_ov), OmegaConf.create(mode_ov), OmegaConf.create(direct_ov))

    # ── 6. extra_args ─────────────────────────────────────────────────────────
    extra_args = mode_dict.get('extra_args', [])
    if extra_args:
        cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(normalize_dotlist_args(extra_args)))

    return cfg

Configuration tracker

config_tracker

AccessTrackedConfig
AccessTrackedConfig(cfg: Union[DictConfig, ListConfig], parent: AccessTrackedConfig = None, key_path: str = '')

Wrapper for OmegaConf to track accessed parameters. Only saves configuration items that were actually accessed during execution.

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def __init__(self, cfg: Union[DictConfig, ListConfig], parent: 'AccessTrackedConfig' = None, key_path: str = ""):
    """Wrap an OmegaConf node to record which keys are accessed."""
    # Bypass __setattr__ so the tracking machinery does not record its own state
    initial_state = (
        ('_cfg', cfg),
        ('_parent', parent),
        ('_key_path', key_path),
        ('_local_accessed', set()),
        ('_children', {}),
    )
    for attr_name, attr_value in initial_state:
        object.__setattr__(self, attr_name, attr_value)

    # Only the root wrapper snapshots the original (resolved) config
    if parent is None:
        AccessTrackedConfig._original_cfg_snapshot = OmegaConf.create(
            OmegaConf.to_container(cfg, resolve=True)
        )
keys
keys()

Return config keys (required for dict unpacking) Tracks all keys as accessed. Only works for DictConfig.

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def keys(self):
    """Return config keys (required for dict unpacking).

    Marks every key as accessed. DictConfig only.
    """
    if self._is_list_config():
        raise TypeError("ListConfig does not support keys()")
    # Record all keys in one shot rather than one add() per iteration
    self._local_accessed.update(self._cfg.keys())
    return self._cfg.keys()
values
values()

Return config values (tracks all keys as accessed)

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def values(self):
    """Yield config values, marking every key/index as accessed."""
    if self._is_list_config():
        for idx in range(len(self._cfg)):
            self._local_accessed.add(f"[{idx}]")
            yield self[idx]
    else:
        for name in self._cfg.keys():
            self._local_accessed.add(name)
            yield self.get(name)
items
items()

Return config items (tracks all keys as accessed)

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def items(self):
    """Yield (key, value) pairs, marking every key as accessed. DictConfig only."""
    if self._is_list_config():
        raise TypeError("ListConfig does not support items()")
    for name in self._cfg.keys():
        self._local_accessed.add(name)
        yield name, self.get(name)
get
get(key: str, default: Any = None) -> Any

Get value with default fallback

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def get(self, key: str, default: Any = None) -> Any:
    """Get value with default fallback, tracking the access.

    Nested config values are returned wrapped (and cached) so accesses
    inside them are tracked too.
    """
    self._local_accessed.add(key)
    value = self._cfg.get(key, default)

    # Plain leaves (or the default itself) pass straight through
    if value is default or not OmegaConf.is_config(value):
        return value

    child_path = f"{self._key_path}.{key}" if self._key_path else key
    child = self._children.get(key)
    if child is None:
        child = AccessTrackedConfig(value, parent=self, key_path=child_path)
        self._children[key] = child
    return child
update
update(other: Any = None, **kwargs)

Update config with values from another dict/config

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def update(self, other: Any = None, **kwargs):
    """Update config with values from another dict/config (DictConfig only)."""
    if self._is_list_config():
        raise TypeError("ListConfig does not support update()")

    def _assign(key, value):
        # Track the write, store it, and drop any stale cached child wrapper
        self._local_accessed.add(key)
        self._cfg[key] = value
        self._children.pop(key, None)

    if other is not None:
        # Normalize the source into a plain dict
        if isinstance(other, AccessTrackedConfig):
            source = OmegaConf.to_container(other._cfg, resolve=True)
        elif OmegaConf.is_config(other):
            source = OmegaConf.to_container(other, resolve=True)
        elif hasattr(other, 'items'):
            # Dict-like object
            source = dict(other.items())
        elif hasattr(other, '__iter__'):
            # Iterable of key-value pairs
            source = dict(other)
        else:
            raise TypeError(f"Cannot update from {type(other)}")

        for key, value in source.items():
            _assign(key, value)

    for key, value in kwargs.items():
        _assign(key, value)
pop
pop(key, *args)

Remove and return a value

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def pop(self, key, *args):
    """Remove and return a value; optional second positional arg is the default.

    Integer keys are tracked in the "[i]" form used for list indices.
    """
    cache_key = f"[{key}]" if isinstance(key, int) else key
    self._local_accessed.add(cache_key)
    # Drop any stale cached child wrapper for this entry
    self._children.pop(cache_key, None)

    if args:
        return self._cfg.pop(key, args[0])
    return self._cfg.pop(key)
append
append(value: Any)

Append value to list (only for ListConfig)

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def append(self, value: Any):
    """Append value to the list and track the new index (ListConfig only)."""
    if not self._is_list_config():
        raise TypeError("append() only supported for ListConfig")
    self._cfg.append(value)
    self._local_accessed.add(f"[{len(self._cfg) - 1}]")
extend
extend(values)

Extend list with values (only for ListConfig)

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def extend(self, values):
    """Extend the list with values, tracking every new index (ListConfig only)."""
    if not self._is_list_config():
        raise TypeError("extend() only supported for ListConfig")
    first_new = len(self._cfg)
    self._cfg.extend(values)
    self._local_accessed.update(f"[{i}]" for i in range(first_new, len(self._cfg)))
setdefault
setdefault(key: str, default: Any = None) -> Any

Set default value if key doesn't exist

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def setdefault(self, key: str, default: Any = None) -> Any:
    """Insert `default` under `key` when absent, then return the stored value."""
    if self._is_list_config():
        raise TypeError("ListConfig does not support setdefault()")
    self._local_accessed.add(key)
    already_present = key in self._cfg
    if not already_present:
        self._cfg[key] = default
    # Route through get() so nested values come back wrapped/tracked
    return self.get(key)
copy
copy() -> AccessTrackedConfig

Return a shallow copy (does not copy access tracking state)

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def copy(self) -> 'AccessTrackedConfig':
    """Return a shallow copy (does not copy access tracking state).

    NOTE(review): despite the name, this resolves interpolations and
    deep-reconstructs the underlying config — the body is identical to
    deepcopy(); it is not a true shallow copy. Confirm before relying on
    shared-substructure semantics.
    """
    new_cfg = OmegaConf.create(OmegaConf.to_container(self._cfg, resolve=True))
    return AccessTrackedConfig(new_cfg)
deepcopy
deepcopy() -> AccessTrackedConfig

Return a deep copy (does not copy access tracking state)

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def deepcopy(self) -> 'AccessTrackedConfig':
    """Return a deep copy (does not copy access tracking state).

    Resolves interpolations, converts to plain containers, and rebuilds a
    fresh OmegaConf tree; the new wrapper becomes a new tracking root.
    """
    new_cfg = OmegaConf.create(OmegaConf.to_container(self._cfg, resolve=True))
    return AccessTrackedConfig(new_cfg)
merge_with
merge_with(*others) -> AccessTrackedConfig

Merge with other configs and return new tracked config

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def merge_with(self, *others) -> 'AccessTrackedConfig':
    """Merge this config with others and return a new tracked config.

    Accepts AccessTrackedConfig, OmegaConf configs, or anything
    OmegaConf.create can consume (dict/list).
    """
    to_merge = [self._cfg]
    for candidate in others:
        if isinstance(candidate, AccessTrackedConfig):
            to_merge.append(candidate._cfg)
        elif OmegaConf.is_config(candidate):
            to_merge.append(candidate)
        else:
            to_merge.append(OmegaConf.create(candidate))

    return AccessTrackedConfig(OmegaConf.merge(*to_merge))
to_dict
to_dict(resolve: bool = True) -> dict

Convert to plain dictionary or list

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def to_dict(self, resolve: bool = True) -> dict:
    """Convert to a plain dictionary or list.

    Args:
        resolve: resolve OmegaConf interpolations before conversion.
    """
    return OmegaConf.to_container(self._cfg, resolve=resolve)
to_yaml
to_yaml(resolve: bool = False) -> str

Convert to YAML string

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def to_yaml(self, resolve: bool = False) -> str:
    """Convert to a YAML string.

    Args:
        resolve: resolve OmegaConf interpolations first (default False).
    """
    return OmegaConf.to_yaml(self._cfg, resolve=resolve)
unwrap
unwrap() -> Union[DictConfig, ListConfig]

Get the underlying OmegaConf object

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def unwrap(self) -> Union[DictConfig, ListConfig]:
    """Get the underlying OmegaConf object (accesses on it are not tracked)."""
    return self._cfg
get_root
get_root() -> AccessTrackedConfig

Get root config object

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def get_root(self) -> 'AccessTrackedConfig':
    """Walk the parent chain and return the root config object."""
    node = self
    while node._parent is not None:
        node = node._parent
    return node
export_accessed_config
export_accessed_config(use_original_values: bool = True) -> dict

Export accessed configuration as dictionary (only leaf values)

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def export_accessed_config(self, use_original_values: bool = True) -> dict:
    """Export the accessed configuration as a dictionary (leaf values only).

    Values come from the original snapshot when `use_original_values` is
    True, falling back to the live config for paths missing from it.
    """
    leaf_paths = self._filter_leaf_paths(self._collect_all_paths())
    if use_original_values:
        source_cfg = AccessTrackedConfig._original_cfg_snapshot
    else:
        source_cfg = self.get_root()._cfg

    result = {}
    for path in sorted(leaf_paths):
        try:
            self._set_nested_value(result, path, self._get_nested_value(source_cfg, path))
        except Exception:
            # Snapshot may predate keys added later; retry against live config
            if use_original_values:
                try:
                    self._set_nested_value(result, path, self._get_nested_value(self.get_root()._cfg, path))
                except Exception:
                    pass
    return result
save_accessed_config
save_accessed_config(filepath: Path, use_original_values: bool = True)

Save accessed configuration to file

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def save_accessed_config(self, filepath: Path, use_original_values: bool = True):
    """Save the accessed configuration to a .json / .yaml / .yml file.

    Args:
        filepath: destination path; parent directories are created.
        use_original_values: export values from the original snapshot
            rather than the (possibly mutated) live config.

    Raises:
        ValueError: for an unsupported file extension. The check happens
            BEFORE the file is opened — the original opened in 'w' mode
            first, truncating/creating the file even on failure.
    """
    accessed_config = self.export_accessed_config(use_original_values=use_original_values)
    filepath = Path(filepath)

    # Validate the format before touching the filesystem
    if filepath.suffix not in ('.json', '.yaml', '.yml'):
        raise ValueError(f"Unsupported file format: {filepath.suffix}")

    filepath.parent.mkdir(parents=True, exist_ok=True)

    with open(filepath, 'w') as f:
        if filepath.suffix == '.json':
            json.dump(accessed_config, f, indent=2)
        else:
            OmegaConf.save(OmegaConf.create(accessed_config), f)
get_access_summary
get_access_summary() -> dict

Get summary of accessed configuration

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def get_access_summary(self) -> dict:
    """Summarize accessed configuration: totals, leaf paths, top-level keys."""
    all_paths = self._collect_all_paths()
    leaf_paths = self._filter_leaf_paths(all_paths)
    root = self.get_root()

    summary = {
        "total_accessed_keys": len(all_paths),
        "leaf_accessed_keys": len(leaf_paths),
        "leaf_accessed_paths": sorted(leaf_paths),
        "top_level_keys": sorted(root._local_accessed),
    }
    return summary
print_access_summary
print_access_summary()

Print a formatted summary of accessed configuration

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def print_access_summary(self):
    """Print a formatted summary of the accessed configuration."""
    summary = self.get_access_summary()
    bar = '=' * 60
    print(f"\n{bar}")
    print("Configuration Access Summary")
    print(bar)
    print(f"Total accessed keys: {summary['total_accessed_keys']}")
    print(f"Leaf accessed keys: {summary['leaf_accessed_keys']}")
    print(f"\nTop-level keys accessed: {summary['top_level_keys']}")
    print("\nLeaf paths accessed:")
    for path in summary['leaf_accessed_paths']:
        print(f"  - {path}")
    print(f"{bar}\n")
wrap_config
wrap_config(cfg: OmegaConf) -> AccessTrackedConfig

Wrap OmegaConf configuration to enable access tracking

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def wrap_config(cfg: OmegaConf) -> AccessTrackedConfig:
    """Wrap an OmegaConf configuration to enable access tracking.

    Args:
        cfg: DictConfig/ListConfig to wrap; becomes a new tracking root.

    Returns:
        AccessTrackedConfig proxy that records every accessed key.
    """
    return AccessTrackedConfig(cfg)
unwrap_config
unwrap_config(cfg) -> OmegaConf

Unwrap AccessTrackedConfig to get underlying OmegaConf object

Source code in AlphaBrain/training/trainer_utils/config_tracker.py
def unwrap_config(cfg) -> OmegaConf:
    """Unwrap an AccessTrackedConfig to its underlying OmegaConf object.

    Pass-through no-op for anything that is not an AccessTrackedConfig.
    """
    return cfg.unwrap() if isinstance(cfg, AccessTrackedConfig) else cfg

Trainer helper functions

trainer_tools

metrics.py

Utility classes defining a Metrics container and multiple Trackers to enable model/stage-specific logging to various endpoints (e.g., JSONL local logs, Weights & Biases).

TrainerUtils
freeze_backbones staticmethod
freeze_backbones(model, freeze_modules='')

directly freeze the specified submodules based on the relative module path list (patterns), no longer recursively find all submodule names: - patterns: read from config.trainer.freeze_modules, separated by commas to get the "relative path" list for example "qwen_vl_interface, action_model.net", it means to freeze model.qwen_vl_interface and model.action_model.net.

Parameters:

Name Type Description Default
model

nn.Module model object

required
freeze_modules

relative module path list (patterns)

''

Returns:

Name Type Description
model

nn.Module model object

return: - model:

Source code in AlphaBrain/training/trainer_utils/trainer_tools.py
@staticmethod
def freeze_backbones(model, freeze_modules=""):
    """
    Freeze the submodules named by a comma-separated list of relative
    attribute paths (no recursive submodule-name search).

    For example freeze_modules="qwen_vl_interface, action_model.net"
    freezes model.qwen_vl_interface and model.action_model.net.

    Args:
        model: nn.Module to (partially) freeze in place.
        freeze_modules: comma-separated relative module paths, read from
            config.trainer.freeze_modules. Empty/falsy freezes nothing.

    Returns:
        model: the same nn.Module, with the selected parameters frozen.
    """
    frozen = []
    if freeze_modules:
        if not dist.is_initialized() or dist.get_rank() == 0:
            print(f"🧊 freeze_modules: {freeze_modules}")
    # isinstance instead of `type(...) == str` (tolerates str subclasses)
    if freeze_modules and isinstance(freeze_modules, str):
        # split and remove whitespace (already known non-empty here; the
        # original's redundant `... if freeze_modules else []` is dropped)
        patterns = [p.strip() for p in freeze_modules.split(",") if p.strip()]

        for path in patterns:
            # walk the dotted path, e.g. "action_model.net" → model.action_model.net
            attrs = path.split(".")
            module = model
            try:
                for attr in attrs:
                    module = getattr(module, attr)
                # freeze the module and all of its submodule parameters
                for param in module.parameters():
                    param.requires_grad = False
                frozen.append(path)
            except AttributeError:
                # unknown path: warn and keep going rather than abort training
                print(f"{_bold_yellow('[warn]')} module path not found, skipping freeze: {_yellow(path)}")
                continue

    # accelerator.wait_for_everyone()  # synchronize when distributed training
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(f"{_cyan('[freeze]')} frozen: {_dim(str(frozen))}")
    return model
print_trainable_parameters staticmethod
print_trainable_parameters(model)

print the total number of parameters and trainable parameters of the model :param model: PyTorch model instance

Source code in AlphaBrain/training/trainer_utils/trainer_tools.py
@staticmethod
def print_trainable_parameters(model):
    """
    Print the total and trainable parameter counts of the model (rank 0
    only) and return the counts on every rank.

    Args:
        model: PyTorch model instance.

    Returns:
        tuple: (num_params, num_trainable_params). The original returned
        None on non-main ranks, making the return value rank-dependent;
        counting is cheap, so every rank now computes and returns it.
    """
    num_params = sum(p.numel() for p in model.parameters())
    num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # guard with is_available() so builds without distributed support work
    is_main = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
    if is_main:
        print(
            f"\033[1;96m[model]\033[0m  "
            f"Total \033[93m{num_params / 10**6:.2f}M\033[0m  "
            f"Trainable \033[1;93m{num_trainable_params / 10**6:.2f}M\033[0m"
        )
    return num_params, num_trainable_params
load_pretrained_backbones staticmethod
load_pretrained_backbones(model, checkpoint_path=None, reload_modules=None)

load checkpoint: - if reload_modules is set, load by path part - otherwise → load the entire model parameters (overwrite model)

return

replace, loaded_modules: list of module paths that successfully loaded parameters; if global load, then ["<full_model>"]

Source code in AlphaBrain/training/trainer_utils/trainer_tools.py
@staticmethod
def load_pretrained_backbones(model, checkpoint_path=None, reload_modules=None):
    """
    Load a checkpoint into `model`:
      - if reload_modules is set, load only the listed submodule paths
      - otherwise load the entire state dict (shape-mismatched keys skipped)

    Args:
        model: nn.Module to load parameters into (modified in place).
        checkpoint_path: weights file, or a directory containing
            `model.safetensors` or `pytorch_model.pt`; no-op when falsy.
        reload_modules: optional comma-separated relative module paths
            (e.g. "action_model.net") to load selectively.

    Returns:
        model: always the nn.Module (BUGFIX: the original returned [] when
        checkpoint_path was falsy, clobbering the model for callers doing
        `model = load_pretrained_backbones(model, ...)`).
    """
    if not checkpoint_path:
        return model
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(f"{_cyan('[ckpt]')} loading {_dim(checkpoint_path)}")

    # resolve a directory checkpoint to a concrete weights file
    resolved_checkpoint_path = checkpoint_path
    if os.path.isdir(checkpoint_path):
        safetensors_path = os.path.join(checkpoint_path, "model.safetensors")
        pt_path = os.path.join(checkpoint_path, "pytorch_model.pt")
        if os.path.exists(safetensors_path):
            resolved_checkpoint_path = safetensors_path
        elif os.path.exists(pt_path):
            resolved_checkpoint_path = pt_path
        else:
            raise RuntimeError(
                f"{_bold_red('[error]')} checkpoint directory does not contain "
                f"`model.safetensors` or `pytorch_model.pt`: {checkpoint_path}"
            )

    try:
        if _is_safetensors_path(resolved_checkpoint_path):
            from safetensors.torch import load_file

            # load the already-resolved file directly (the original
            # redundantly re-derived the path from checkpoint_path)
            checkpoint = load_file(str(resolved_checkpoint_path))
        else:
            checkpoint = torch.load(resolved_checkpoint_path, map_location="cpu")
    except Exception as e:
        raise RuntimeError(f"{_bold_red('[error]')} loading checkpoint failed: {e}")

    loaded_modules = []

    if reload_modules:  # partial load
        module_paths = [p.strip() for p in reload_modules.split(",") if p.strip()]
        for path in module_paths:
            # renamed from `reload_modules`, which shadowed the parameter
            attr_chain = path.split(".")
            module = model
            try:
                for module_name in attr_chain:  # find the module to modify level by level
                    module = getattr(module, module_name)
                prefix = path + "."
                sub_state_dict = {k[len(prefix) :]: v for k, v in checkpoint.items() if k.startswith(prefix)}
                if sub_state_dict:
                    module.load_state_dict(sub_state_dict, strict=True)
                    if not dist.is_initialized() or dist.get_rank() == 0:
                        print(f"{_bold_green('[ok]')} loaded module {_yellow(repr(path))}")
                    loaded_modules.append(path)
                else:
                    print(f"{_bold_yellow('[warn]')} key not found in checkpoint: {_yellow(repr(path))}")
            except AttributeError:
                print(f"{_bold_red('[error]')} module path not found: {_yellow(repr(path))}")
    else:  # full load
        try:
            # Filter out shape-mismatched keys (e.g. action_dim 32→7)
            model_state = model.state_dict()
            filtered_checkpoint = {}
            skipped_keys = []
            for k, v in checkpoint.items():
                if k in model_state and model_state[k].shape != v.shape:
                    skipped_keys.append(f"{k}: ckpt {tuple(v.shape)} vs model {tuple(model_state[k].shape)}")
                else:
                    filtered_checkpoint[k] = v
            if skipped_keys and (not dist.is_initialized() or dist.get_rank() == 0):
                print(f"{_bold_yellow('[warn]')} skipped {len(skipped_keys)} shape-mismatched keys:")
                for sk in skipped_keys:
                    print(f"  {_dim(sk)}")
            model.load_state_dict(filtered_checkpoint, strict=False)
            if not dist.is_initialized() or dist.get_rank() == 0:
                print(f"{_bold_green('[ok]')} loaded {_bold_cyan('<full_model>')} parameters")
            loaded_modules = ["<full_model>"]
        except Exception as e:
            raise RuntimeError(f"{_bold_red('[error]')} loading full model failed: {e}")
    return model
print_freeze_status staticmethod
print_freeze_status(model)

print the freezing status of each parameter in the model :param model: PyTorch model instance

Source code in AlphaBrain/training/trainer_utils/trainer_tools.py
@staticmethod
def print_freeze_status(model):
    """
    Print one line per named parameter showing whether it is frozen.
    :param model: PyTorch model instance
    """
    for name, param in model.named_parameters():
        state = "Trainable" if param.requires_grad else "Frozen"
        print(f"{name:60s}  |  {state}")
setup_distributed_training staticmethod
setup_distributed_training(accelerator, *components)

use Accelerator to prepare distributed training components :param accelerator: Accelerate instance :param components: any number of components (such as model, optimizer, dataloader, etc.) :return: prepared distributed components (in the same order as input)

Source code in AlphaBrain/training/trainer_utils/trainer_tools.py
@staticmethod
def setup_distributed_training(accelerator, *components):
    """
    Prepare components for distributed training via Accelerate.
    :param accelerator: Accelerate instance
    :param components: any number of components (model, optimizer, dataloader, ...)
    :return: prepared components, in the same order as the input
    """
    prepared = accelerator.prepare(*components)

    # For DDP with parameter reuse (e.g. shared attention in PaliGemmaOFT),
    # set static_graph to allow parameters being used multiple times
    as_iterable = prepared if isinstance(prepared, tuple) else [prepared]
    for item in as_iterable:
        if hasattr(item, "module") and hasattr(item, "_set_static_graph"):
            item._set_static_graph()

    return prepared
compute_grad_angle_with_stats staticmethod
compute_grad_angle_with_stats(grads_a: list[Tensor], grads_v: list[Tensor]) -> Tuple[float, float]

compute the cosine angle between two groups of gradient vectors (degrees), and calculate the average angle and variance. grads_a, grads_v: gradient Tensor list corresponding to the same parameter list interface_params return: mean_angle_deg: average angle (degrees) angle_variance: angle variance

Source code in AlphaBrain/training/trainer_utils/trainer_tools.py
@staticmethod
def compute_grad_angle_with_stats(grads_a: list[torch.Tensor], grads_v: list[torch.Tensor]) -> Tuple[float, float]:
    """
    compute the cosine angle between two groups of gradient vectors (degrees), and calculate the average angle and variance.
    grads_a, grads_v: gradient Tensor list corresponding to the same parameter list interface_params

    NOTE(review): only grads_a[0] / grads_v[0] are used, sliced to [:32, :7];
    this assumes both are 2-D tensors (comments suggest [2048, 11008]) —
    confirm against the caller before reuse.
    NOTE(review): the second return value is computed as sqrt(var), i.e. the
    standard deviation, despite being named `angle_variance`.

    return:
        mean_angle_deg: average angle (degrees)
        angle_variance: spread of the angles (see NOTE: actually the std)
    """
    angle_degs = []

    # compute the cosine angle between each gradient block grads_a[0].shape = 1280, 3, 14, 14
    # grads_1 = grads_a[0][0]  # [3, 14, 14]
    # grads_2 = grads_v[0][0]
    # grads_a = grads_1.view(-1, 3)  # reshape to [196, 3]
    # grads_v = grads_2.view(-1, 3)

    # lang linear
    # reshape to 14*14, 3
    # layer
    grads_action = grads_a[0]  # [2048, 11008]
    grads_action = grads_action[
        :32, :7
    ]  # only take the first 7 elements, avoid cosim failure in high-dimensional space
    grads_vl = grads_v[0]  # [2048, 11008]
    grads_vl = grads_vl[
        :32, :7
    ]  # only take the first 32 elements, 7 dimensions, avoid cosim failure in high-dimensional space
    # one cosine angle per row of the sliced blocks
    for g_a, g_v in zip(grads_action, grads_vl):
        dot = torch.sum(g_a * g_v)
        norm_a_sq = torch.sum(g_a * g_a)
        norm_v_sq = torch.sum(g_v * g_v)

        # avoid division by zero
        norm_a = torch.sqrt(norm_a_sq + 1e-16)
        norm_v = torch.sqrt(norm_v_sq + 1e-16)

        cos_sim = (dot / (norm_a * norm_v)).clamp(-1.0, 1.0)
        angle_rad = torch.acos(cos_sim)
        angle_deg = angle_rad * (180.0 / torch.pi)

        angle_degs.append(angle_deg.item())

    # compute the average angle and variance
    angle_degs_tensor = torch.tensor(angle_degs)
    mean_angle_deg = torch.mean(angle_degs_tensor).item()
    angle_variance = torch.sqrt(torch.var(angle_degs_tensor)).item()
    # accelerator.wait_for_everyone()
    return mean_angle_deg, angle_variance
pcgrad_project staticmethod
pcgrad_project(grads_a: list[Tensor], grads_v: list[Tensor]) -> list[torch.Tensor]

apply PCGrad projection to the second group of gradients grads_v, suppress negative transfer between grads_a and grads_v if the dot product of two groups of gradients < 0, then: grads_v <- grads_v - (dot / ||grads_a||^2) * grads_a return the new grads_v list

Source code in AlphaBrain/training/trainer_utils/trainer_tools.py
@staticmethod
def pcgrad_project(grads_a: list[torch.Tensor], grads_v: list[torch.Tensor]) -> list[torch.Tensor]:
    """
    Apply PCGrad projection to the second gradient group `grads_v` to
    suppress negative transfer against `grads_a`. When the total dot
    product of the two groups is negative:
        grads_v <- grads_v - (dot / ||grads_a||^2) * grads_a
    Returns the (possibly projected) grads_v list.
    """
    # accumulate the global dot product and ||grads_a||^2 across all blocks
    dot = 0.0
    norm_a_sq = 0.0
    for block_a, block_v in zip(grads_a, grads_v):
        dot = dot + torch.sum(block_a * block_v)
        norm_a_sq = norm_a_sq + torch.sum(block_a * block_a)

    if dot < 0:
        scale = dot / (norm_a_sq + 1e-6)
        # project the conflicting component out of grads_v
        return [block_v - scale * block_a for block_a, block_v in zip(grads_a, grads_v)]

    return grads_v
eval_qwenpi staticmethod
eval_qwenpi(qwenpi, dataloader, num_batches=20)

evaluate QwenQFormerDiT model, compute IoU and action distance.

Parameters:

Name Type Description Default
qwenpi

QwenQFormerDiT model instance.

required
dataloader

data loader.

required
num_batches

number of batches to evaluate.

20

Returns:

Name Type Description
dict

contains IoU and action distance evaluation results.

Source code in AlphaBrain/training/trainer_utils/trainer_tools.py
@staticmethod
def eval_qwenpi(qwenpi, dataloader, num_batches=20):
    """
    evaluate QwenQFormerDiT model, compute IoU and action distance.

    Args:
        qwenpi: QwenQFormerDiT model instance.
        dataloader: data loader yielding lists of dicts with keys
            "image", "lang", "action", "solution".
        num_batches: maximum number of batches to evaluate.

    Returns:
        dict: {"iou_scores": [...], "average_action_distance": ...}.
            average_action_distance is NaN when no batch was evaluated.
    """
    iou_scores = []
    action_distances = []
    count = 0

    dataset_iter = iter(dataloader)
    while count < num_batches:
        try:
            batch_samples = next(dataset_iter)
            count += 1
        except StopIteration:
            break

        # extract data
        images = [example["image"] for example in batch_samples]
        instructions = [example["lang"] for example in batch_samples]
        actions = [example["action"] for example in batch_samples]
        solutions = [example["solution"] for example in batch_samples]

        # model prediction (CoT text + normalized action trajectory)
        predicted_solutions, normalized_actions = qwenpi.predict_action_withCoT(
            images=images, instructions=instructions, use_ddim=False, num_ddim_steps=20
        )

        # parse CoT text into dicts; extract_json_from_string returns None
        # for output that does not contain valid JSON
        parsed_solutions = [
            TrainerUtils.extract_json_from_string(solution)
            for solution in predicted_solutions
        ]

        # compute IoU; skip samples whose prediction failed to parse or is
        # missing the expected pick/place bboxes (previously this crashed
        # with TypeError on `None["pick"]` / KeyError on malformed dicts)
        for pred_dict, gt_dict in zip(parsed_solutions, solutions):
            if pred_dict is None:
                continue
            try:
                pred_pick_bbox = torch.tensor(pred_dict["pick"]["bbox_2d"], dtype=torch.float32).unsqueeze(0)
                pred_place_bbox = torch.tensor(pred_dict["place"]["bbox_2d"], dtype=torch.float32).unsqueeze(0)
            except (KeyError, TypeError):
                continue
            gt_pick_bbox = torch.tensor(gt_dict["pick"]["bbox_2d"], dtype=torch.float32).unsqueeze(0)
            gt_place_bbox = torch.tensor(gt_dict["place"]["bbox_2d"], dtype=torch.float32).unsqueeze(0)

            pick_iou = box_iou(pred_pick_bbox, gt_pick_bbox).item()
            place_iou = box_iou(pred_place_bbox, gt_place_bbox).item()

            iou_scores.append({"pick_iou": pick_iou, "place_iou": place_iou})

        # compute per-element action distance, normalized by B*len*dim
        actions = np.array(actions)  # convert to numpy array
        num_elements = np.prod(actions.shape)  # B*len*dim
        action_distance = TrainerUtils.euclidean_distance(normalized_actions, actions)
        average_action_distance = action_distance / num_elements
        action_distances.append(average_action_distance)

    # summarize results; avoid numpy's empty-slice RuntimeWarning / silent
    # NaN when the dataloader produced no batches
    avg_action_distance = np.mean(action_distances) if action_distances else float("nan")
    return {"iou_scores": iou_scores, "average_action_distance": avg_action_distance}
extract_json_from_string staticmethod
extract_json_from_string(input_string)

extract valid JSON part from string and convert to dictionary.

Parameters:

Name Type Description Default
input_string str

string containing extra characters.

required

Returns:

Name Type Description
dict

dictionary extracted and parsed.

Source code in AlphaBrain/training/trainer_utils/trainer_tools.py
@staticmethod
def extract_json_from_string(input_string):
    """
    Pull the JSON span (first `{` through last `}`) out of a noisy string
    and parse it.

    Args:
        input_string (str): raw text that may wrap a JSON object in extra
            characters.

    Returns:
        dict | None: the parsed object, or None when no span is found or
            parsing fails (a diagnostic is printed in both cases).
    """
    match = re.search(r"{.*}", input_string, re.DOTALL)
    if match is None:
        print("No valid JSON part found")
        return None
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError as e:
        print(f"JSON decode failed: {e}")
        return None
normalize_dotlist_args
normalize_dotlist_args(args)

Convert ['--x.y', 'val'] and ['--flag'] → ['x.y=val', 'flag=true']

Source code in AlphaBrain/training/trainer_utils/trainer_tools.py
def normalize_dotlist_args(args):
    """
    Convert ['--x.y', 'val'] and ['--flag'] → ['x.y=val', 'flag=true']
    """
    normalized = []
    i = 0
    total = len(args)
    while i < total:
        token = args[i]
        if token.startswith("--"):
            key = token.lstrip("-")
            if "=" in key:
                # already key=value form
                normalized.append(key)
            elif i + 1 < total and not args[i + 1].startswith("--"):
                # next token is this option's value
                normalized.append(f"{key}={args[i + 1]}")
                i += 1  # value consumed
            else:
                # bare flag defaults to true
                normalized.append(f"{key}=true")
        elif "=" in token:
            # Bare dotlist format: key=value (Hydra/OmegaConf style)
            normalized.append(token)
        # anything else is an orphaned value; drop it
        i += 1
    return normalized
build_param_lr_groups
build_param_lr_groups(model, cfg)

build multiple param groups based on cfg.trainer.learning_rate. support specifying different learning rates for different modules, the rest use base.

Parameters:

Name Type Description Default
vla

nn.Module model object

required
cfg

config object, requires cfg.trainer.learning_rate dictionary

required

Returns:

Type Description

List[Dict]: param_groups that can be used to build optimizer with torch.optim

Source code in AlphaBrain/training/trainer_utils/trainer_tools.py
def build_param_lr_groups(model, cfg):
    """
    build multiple param groups based on cfg.trainer.learning_rate.
    support specifying different learning rates for different modules
    (nested attribute paths as keys), the rest use base. Modules listed in
    cfg.trainer.freeze_modules (comma-separated paths) are excluded entirely.

    Args:
        model: nn.Module model object
        cfg: config object, requires cfg.trainer.learning_rate dictionary

    Returns:
        List[Dict]: param_groups that can be used to build optimizer with torch.optim
    """

    lr_cfg = cfg.trainer.learning_rate
    base_lr = lr_cfg.get("base", 1e-4)  # default base learning rate

    freeze_modules = cfg.trainer.get("freeze_modules", "")
    if not isinstance(freeze_modules, str):
        freeze_modules = ""
    freeze_patterns = [p.strip() for p in freeze_modules.split(",") if p.strip()]

    used_params = set()
    frozen_params = set()
    param_groups = []

    # collect ids of all parameters under frozen module paths
    for freeze_path in freeze_patterns:
        module = model
        try:
            for attr in freeze_path.split("."):
                module = getattr(module, attr)
            frozen_params.update(id(p) for p in module.parameters())
        except AttributeError:
            print(f"{_bold_yellow('[warn]')} freeze path not found: {_dim(str(freeze_path))}")
            continue

    for module_name, lr in lr_cfg.items():
        if module_name == "base":
            continue
        # try to find the module under model by module_name (support nested paths)
        module = model
        try:
            for attr in module_name.split("."):
                module = getattr(module, attr)
            # filter out frozen params, params already claimed by an earlier
            # group (overlapping nested paths would otherwise put one tensor
            # in two optimizer groups, which torch.optim rejects), and
            # non-trainable params
            params = [
                p for p in module.parameters()
                if id(p) not in frozen_params and id(p) not in used_params and p.requires_grad
            ]
            if params:  # only add param group if there are trainable parameters
                param_groups.append({"params": params, "lr": lr, "name": module_name})
                used_params.update(id(p) for p in params)
        except AttributeError:
            # BUG FIX: previously `ReferenceError(...)` was constructed but
            # never raised — a silent no-op that hid typos in the
            # learning_rate config. Surface a visible warning instead
            # (raising would change behavior for existing callers).
            print(f"⚠️ module path `{module_name}` not found in vla")

    # assign base learning rate to the remaining unused parameters (exclude frozen ones)
    other_params = [p for p in model.parameters() if id(p) not in used_params and id(p) not in frozen_params and p.requires_grad]
    if other_params:
        param_groups.append({"params": other_params, "lr": base_lr, "name": "base"})

    return param_groups
only_main_process
only_main_process(func)

decorator: only run in main process (rank=0)

Source code in AlphaBrain/training/trainer_utils/trainer_tools.py
def only_main_process(func):
    """
    decorator: only run in main process (rank=0)

    On non-main ranks (when torch.distributed is initialized) the wrapped
    function becomes a no-op returning None.
    """
    from functools import wraps

    @wraps(func)  # preserve __name__/__doc__ for logging and debugging
    def wrapper(*args, **kwargs):
        if dist.is_initialized() and dist.get_rank() != 0:
            return None  # non-main process does not execute
        return func(*args, **kwargs)

    return wrapper
resize_images
resize_images(images, target_size=(224, 224))

recursively resize all images in the nested list.

:param images: nested list of images or single image. :param target_size: target size (width, height) after resizing. :return: resized images list, keeping the original nested structure.

Source code in AlphaBrain/training/trainer_utils/trainer_tools.py
def resize_images(images, target_size=(224, 224)):
    """
    recursively resize all images in the nested list.

    :param images: nested list of images, a single PIL image, or a numpy array.
    :param target_size: target size (width, height) after resizing.
    :return: resized image(s), keeping the original nested structure.
    """
    if isinstance(images, list):
        # recurse: rebuild the list with every element resized
        return [resize_images(item, target_size) for item in images]
    if isinstance(images, np.ndarray):
        # numpy frame: wrap as PIL first, then resize
        return Image.fromarray(images).resize(target_size)
    if isinstance(images, Image.Image):
        return images.resize(target_size)
    raise ValueError("Unsupported image type or structure.")

PEFT integration

peft

LoRA / PEFT helpers shared across all trainers.

Public API
is_lora_enabled(cfg)                            -> bool
apply_lora(model, cfg)                          -> model (in-place)
save_lora_checkpoint(accelerator, model, base_path, cfg)
load_and_merge(base_model_factory, lora_adapter_dir,
               action_model_pt, output_path, vlm_module=None)
Schema

The lora: block in training yaml is parsed by LoRASpec.from_omega. See config.py for the recognized fields. Backward-compatible with all existing yaml configs under configs/continual_learning/.

Checkpoint layout (unchanged from previous inline implementation): _lora_adapter/ PEFT adapter directory _action_model.pt Non-VLM trainable weights

LoRASpec dataclass
LoRASpec(rank: int = 32, alpha: int = 16, dropout: float = 0.05, target_modules: Any = 'all-linear', init_lora_weights: str = 'gaussian', vlm_module: str | None = None, freeze_extra_modules: list[str] = list())

Backbone-agnostic LoRA application spec.

Resolved from yaml lora: block via :meth:from_omega.

from_omega classmethod
from_omega(cfg: Any) -> 'LoRASpec'

Parse from yaml/OmegaConf lora: block.

Tolerant of: - Missing lora key (returns defaults; caller should check is_lora_enabled) - freeze_extra_modules as comma-separated string OR list - target_modules as string ("all-linear") OR list

Source code in AlphaBrain/training/trainer_utils/peft/config.py
@classmethod
def from_omega(cls, cfg: Any) -> "LoRASpec":
    """Parse from yaml/OmegaConf `lora:` block.

    Tolerant of:
    - Missing `lora` key (returns defaults; caller should check `is_lora_enabled`)
    - `freeze_extra_modules` as comma-separated string OR list
    - `target_modules` as string ("all-linear") OR list
    """
    if hasattr(cfg, "get"):
        block = cfg.get("lora", {})
    else:
        block = getattr(cfg, "lora", {})
    if block is None:
        block = {}

    # dict-like configs expose .get; attribute-style configs fall back to getattr
    if hasattr(block, "get"):
        read = block.get
    else:
        read = lambda k, d=None: getattr(block, k, d)

    raw_freeze = read("freeze_extra_modules", []) or []
    if isinstance(raw_freeze, str):
        freeze = [m.strip() for m in raw_freeze.split(",") if m.strip()]
    elif isinstance(raw_freeze, (list, tuple)):
        freeze = list(raw_freeze)
    else:
        # OmegaConf ListConfig
        try:
            freeze = list(raw_freeze)
        except TypeError:
            freeze = []

    targets = read("target_modules", "all-linear")
    if not isinstance(targets, str):
        # OmegaConf list -> python list
        try:
            targets = list(targets)
        except TypeError:
            pass

    return cls(
        rank=int(read("rank", 32)),
        alpha=int(read("alpha", 16)),
        dropout=float(read("dropout", 0.05)),
        target_modules=targets,
        init_lora_weights=str(read("init_lora_weights", "gaussian")),
        vlm_module=read("vlm_module", None),
        freeze_extra_modules=freeze,
    )
peft_config
peft_config()

Build a peft.LoraConfig from this spec.

Source code in AlphaBrain/training/trainer_utils/peft/config.py
def peft_config(self):
    """Build a `peft.LoraConfig` from this spec."""
    from peft import LoraConfig

    # spell out the field -> LoraConfig keyword mapping explicitly
    kwargs = {
        "r": self.rank,
        "lora_alpha": self.alpha,
        "lora_dropout": self.dropout,
        "target_modules": self.target_modules,
        "init_lora_weights": self.init_lora_weights,
    }
    return LoraConfig(**kwargs)
is_lora_enabled
is_lora_enabled(cfg: Any) -> bool

Return True iff cfg.lora.enabled is set.

Source code in AlphaBrain/training/trainer_utils/peft/config.py
def is_lora_enabled(cfg: Any) -> bool:
    """Return True iff `cfg.lora.enabled` is set."""
    if cfg is None:
        return False
    block = cfg.get("lora", {}) if hasattr(cfg, "get") else getattr(cfg, "lora", {})
    # OmegaConf DictConfig supports .get; plain objects use getattr
    if isinstance(block, DictConfig):
        return bool(block.get("enabled", False))
    if not block:
        return False
    return bool(getattr(block, "enabled", False))
apply_lora
apply_lora(model: Module, cfg: Any, *, print_summary: bool = True) -> nn.Module

Apply LoRA in-place per spec.

Steps
  1. Resolve VLM interface (from lora.vlm_module or auto-detect via _VLM_REGISTRY).
  2. Freeze ALL params of the VLM interface wrapper.
  3. Replace vlm_interface.model = get_peft_model(...) so PEFT injects LoRA layers (their params are trainable, base remains frozen).
  4. Freeze each module listed in lora.freeze_extra_modules.
  5. Modules not touched by the above stay with their original requires_grad (typically full-FT — e.g. action_model, dino).

Returns the same model instance (mutated in place).

Source code in AlphaBrain/training/trainer_utils/peft/injector.py
def apply_lora(
    model: nn.Module,
    cfg: Any,
    *,
    print_summary: bool = True,
) -> nn.Module:
    """Inject LoRA adapters into *model* in place, per the yaml `lora:` block.

    The VLM interface is resolved from `lora.vlm_module` (or auto-detected
    via `_VLM_REGISTRY`); its backbone is frozen and wrapped with PEFT so
    only the injected LoRA layers train. Modules named in
    `lora.freeze_extra_modules` are frozen as well; everything else keeps its
    original `requires_grad` (typically full-FT — e.g. action_model, dino).

    Returns the same `model` instance (mutated in place).
    """
    if not is_lora_enabled(cfg):
        return model

    from peft import get_peft_model
    from AlphaBrain.model.framework.base_framework import _detect_vlm_interface

    spec = LoRASpec.from_omega(cfg)
    lora_config = spec.peft_config()

    # Resolve the VLM interface wrapper that will receive the adapters.
    if not spec.vlm_module:
        vlm_interface = _detect_vlm_interface(model)
    elif hasattr(model, spec.vlm_module):
        vlm_interface = getattr(model, spec.vlm_module)
    else:
        raise AttributeError(
            f"lora.vlm_module='{spec.vlm_module}' not found on model "
            f"(available: {[n for n, _ in model.named_children()]})"
        )
    assert vlm_interface is not None, (
        "No VLM interface found for LoRA injection. "
        "Set lora.vlm_module explicitly in config."
    )

    # Freeze the whole backbone, then let PEFT mark only LoRA params trainable.
    for param in vlm_interface.parameters():
        param.requires_grad = False
    vlm_interface.model = get_peft_model(vlm_interface.model, lora_config)

    # Freeze any extra modules requested by the config.
    for module_name in spec.freeze_extra_modules:
        if not hasattr(model, module_name):
            logger.warning(f"freeze_extra_modules: '{module_name}' not found, skipping")
            continue
        extra_module = getattr(model, module_name)
        n = 0
        for param in extra_module.parameters():
            param.requires_grad = False
            n += 1
        logger.info(f"Froze extra module '{module_name}' ({n} param tensors)")

    # Summary logging.
    if print_summary:
        logger.info("LoRA enabled on VLM backbone")
        try:
            vlm_interface.model.print_trainable_parameters()
        except Exception:
            pass
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        logger.info(
            f"Total trainable: {trainable / 1e6:.1f}M / {total / 1e6:.1f}M "
            f"({100 * trainable / max(total, 1):.2f}%)"
        )

    return model
save_lora_checkpoint
save_lora_checkpoint(*, accelerator, model, base_path: str, cfg: Any) -> None

Save LoRA adapter + non-VLM weights for a checkpoint.

Creates

_lora_adapter/ (PEFT adapter) _action_model.pt (all keys NOT starting with a VLM attr prefix)

Source code in AlphaBrain/training/trainer_utils/peft/checkpoint.py
def save_lora_checkpoint(
    *,
    accelerator,
    model,
    base_path: str,
    cfg: Any,
) -> None:
    """Persist a LoRA checkpoint: PEFT adapter + all non-VLM weights.

    Creates:
      <base_path>_lora_adapter/        (PEFT adapter)
      <base_path>_action_model.pt      (all keys NOT starting with a VLM attr prefix)
    """
    unwrapped = accelerator.unwrap_model(model)

    # lora.vlm_module may live on a dict-like or attribute-style config
    if hasattr(cfg, "get"):
        vlm_module = cfg.get("lora", {}).get("vlm_module")
    else:
        vlm_module = getattr(getattr(cfg, "lora", None), "vlm_module", None)

    # 1. Save the PEFT adapter directory
    vlm_interface = _resolve_vlm_interface(unwrapped, vlm_module)
    adapter_path = base_path + "_lora_adapter"
    vlm_interface.model.save_pretrained(adapter_path)

    # 2. Save every weight whose key falls outside the VLM attr prefixes
    vlm_prefixes = _vlm_attr_prefixes(unwrapped)
    full_state = accelerator.get_state_dict(model)
    non_vlm_state = {key: val for key, val in full_state.items() if not key.startswith(vlm_prefixes)}
    torch.save(non_vlm_state, base_path + "_action_model.pt")

    logger.info(
        f"LoRA checkpoint saved: {adapter_path} + non-VLM weights "
        f"({len(non_vlm_state)} keys)"
    )
load_and_merge
load_and_merge(*, base_model_factory: Callable[[], 'torch.nn.Module'], lora_adapter_dir: str, action_model_pt: str, output_path: str, vlm_module: str | None = None) -> None

Build base model, attach LoRA adapter, merge, load extras, save full ckpt.

The output is a single .pt file usable by BaseFramework.from_pretrained, suitable for the standard server_policy + eval_libero pipeline.

Source code in AlphaBrain/training/trainer_utils/peft/checkpoint.py
def load_and_merge(
    *,
    base_model_factory: Callable[[], "torch.nn.Module"],
    lora_adapter_dir: str,
    action_model_pt: str,
    output_path: str,
    vlm_module: str | None = None,
) -> None:
    """Build base model, attach LoRA adapter, merge, load extras, save full ckpt.

    The output is a single `.pt` file usable by `BaseFramework.from_pretrained`,
    suitable for the standard server_policy + eval_libero pipeline.

    Args:
        base_model_factory: zero-arg callable building the un-adapted model.
        lora_adapter_dir: PEFT adapter directory from `save_lora_checkpoint`.
        action_model_pt: path to the non-VLM weights (`*_action_model.pt`).
        output_path: destination `.pt` for the merged full state dict.
        vlm_module: explicit VLM attribute name; None → auto-detect.
    """
    from peft import PeftModel

    print("[1/4] Build base model")  # was an f-string with no placeholders
    vla = base_model_factory()

    print(f"[2/4] Attach + merge LoRA adapter from {lora_adapter_dir}")
    vlm_interface = _resolve_vlm_interface(vla, vlm_module)
    vlm_interface.model = PeftModel.from_pretrained(
        vlm_interface.model,
        lora_adapter_dir,
    )
    vlm_interface.model = vlm_interface.model.merge_and_unload()
    print("  LoRA merged into VLM backbone")

    print(f"[3/4] Load non-VLM weights from {action_model_pt}")
    # NOTE(security): torch.load unpickles arbitrary objects — only feed it
    # checkpoints from trusted sources.
    non_vlm_state = torch.load(action_model_pt, map_location="cpu")
    missing, unexpected = vla.load_state_dict(non_vlm_state, strict=False)
    if unexpected:
        print(f"  WARNING: unexpected keys: {unexpected[:5]}...")
    print(
        f"  Loaded {len(non_vlm_state)} non-VLM keys "
        f"(missing {len(missing)} VLM keys as expected — recovered via LoRA merge)"
    )

    print(f"[4/4] Save merged checkpoint to {output_path}")
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    full_state = vla.state_dict()
    torch.save(full_state, output_path)
    size_mb = os.path.getsize(output_path) / (1024 * 1024)
    print(f"  Done! Merged checkpoint: {size_mb:.0f} MB")
checkpoint

LoRA checkpoint save / load+merge.

File-name conventions are kept identical to the previous inline code, so existing checkpoints (5d / 5h / 5l etc.) remain merge-and-eval compatible:

<base_path>_lora_adapter/        ← PEFT adapter directory
  adapter_config.json
  adapter_model.safetensors
<base_path>_action_model.pt      ← non-VLM weights (action_model + extras
                                   like layer_qformer / edit_model / dino)
save_lora_checkpoint
save_lora_checkpoint(*, accelerator, model, base_path: str, cfg: Any) -> None

Save LoRA adapter + non-VLM weights for a checkpoint.

Creates

_lora_adapter/ (PEFT adapter) _action_model.pt (all keys NOT starting with a VLM attr prefix)

Source code in AlphaBrain/training/trainer_utils/peft/checkpoint.py
def save_lora_checkpoint(
    *,
    accelerator,
    model,
    base_path: str,
    cfg: Any,
) -> None:
    """Persist a LoRA checkpoint: PEFT adapter + all non-VLM weights.

    Creates:
      <base_path>_lora_adapter/        (PEFT adapter)
      <base_path>_action_model.pt      (all keys NOT starting with a VLM attr prefix)
    """
    unwrapped = accelerator.unwrap_model(model)

    # lora.vlm_module may live on a dict-like or attribute-style config
    if hasattr(cfg, "get"):
        vlm_module = cfg.get("lora", {}).get("vlm_module")
    else:
        vlm_module = getattr(getattr(cfg, "lora", None), "vlm_module", None)

    # 1. Save the PEFT adapter directory
    vlm_interface = _resolve_vlm_interface(unwrapped, vlm_module)
    adapter_path = base_path + "_lora_adapter"
    vlm_interface.model.save_pretrained(adapter_path)

    # 2. Save every weight whose key falls outside the VLM attr prefixes
    vlm_prefixes = _vlm_attr_prefixes(unwrapped)
    full_state = accelerator.get_state_dict(model)
    non_vlm_state = {key: val for key, val in full_state.items() if not key.startswith(vlm_prefixes)}
    torch.save(non_vlm_state, base_path + "_action_model.pt")

    logger.info(
        f"LoRA checkpoint saved: {adapter_path} + non-VLM weights "
        f"({len(non_vlm_state)} keys)"
    )
load_and_merge
load_and_merge(*, base_model_factory: Callable[[], 'torch.nn.Module'], lora_adapter_dir: str, action_model_pt: str, output_path: str, vlm_module: str | None = None) -> None

Build base model, attach LoRA adapter, merge, load extras, save full ckpt.

The output is a single .pt file usable by BaseFramework.from_pretrained, suitable for the standard server_policy + eval_libero pipeline.

Source code in AlphaBrain/training/trainer_utils/peft/checkpoint.py
def load_and_merge(
    *,
    base_model_factory: Callable[[], "torch.nn.Module"],
    lora_adapter_dir: str,
    action_model_pt: str,
    output_path: str,
    vlm_module: str | None = None,
) -> None:
    """Build base model, attach LoRA adapter, merge, load extras, save full ckpt.

    The output is a single `.pt` file usable by `BaseFramework.from_pretrained`,
    suitable for the standard server_policy + eval_libero pipeline.

    Args:
        base_model_factory: zero-arg callable building the un-adapted model.
        lora_adapter_dir: PEFT adapter directory from `save_lora_checkpoint`.
        action_model_pt: path to the non-VLM weights (`*_action_model.pt`).
        output_path: destination `.pt` for the merged full state dict.
        vlm_module: explicit VLM attribute name; None → auto-detect.
    """
    from peft import PeftModel

    print("[1/4] Build base model")  # was an f-string with no placeholders
    vla = base_model_factory()

    print(f"[2/4] Attach + merge LoRA adapter from {lora_adapter_dir}")
    vlm_interface = _resolve_vlm_interface(vla, vlm_module)
    vlm_interface.model = PeftModel.from_pretrained(
        vlm_interface.model,
        lora_adapter_dir,
    )
    vlm_interface.model = vlm_interface.model.merge_and_unload()
    print("  LoRA merged into VLM backbone")

    print(f"[3/4] Load non-VLM weights from {action_model_pt}")
    # NOTE(security): torch.load unpickles arbitrary objects — only feed it
    # checkpoints from trusted sources.
    non_vlm_state = torch.load(action_model_pt, map_location="cpu")
    missing, unexpected = vla.load_state_dict(non_vlm_state, strict=False)
    if unexpected:
        print(f"  WARNING: unexpected keys: {unexpected[:5]}...")
    print(
        f"  Loaded {len(non_vlm_state)} non-VLM keys "
        f"(missing {len(missing)} VLM keys as expected — recovered via LoRA merge)"
    )

    print(f"[4/4] Save merged checkpoint to {output_path}")
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    full_state = vla.state_dict()
    torch.save(full_state, output_path)
    size_mb = os.path.getsize(output_path) / (1024 * 1024)
    print(f"  Done! Merged checkpoint: {size_mb:.0f} MB")
config

LoRA spec parsed from yaml lora: section.

Recognized fields (current schema, kept stable for backward-compat): enabled bool rank int (default 32) alpha int (default 16) dropout float (default 0.05) target_modules str | list[str] (default "all-linear") init_lora_weights str (default "gaussian") vlm_module str | None (default None → auto-detect) freeze_extra_modules str | list[str] (default [])

LoRASpec dataclass
LoRASpec(rank: int = 32, alpha: int = 16, dropout: float = 0.05, target_modules: Any = 'all-linear', init_lora_weights: str = 'gaussian', vlm_module: str | None = None, freeze_extra_modules: list[str] = list())

Backbone-agnostic LoRA application spec.

Resolved from yaml lora: block via :meth:from_omega.

from_omega classmethod
from_omega(cfg: Any) -> 'LoRASpec'

Parse from yaml/OmegaConf lora: block.

Tolerant of: - Missing lora key (returns defaults; caller should check is_lora_enabled) - freeze_extra_modules as comma-separated string OR list - target_modules as string ("all-linear") OR list

Source code in AlphaBrain/training/trainer_utils/peft/config.py
@classmethod
def from_omega(cls, cfg: Any) -> "LoRASpec":
    """Parse from yaml/OmegaConf `lora:` block.

    Tolerant of:
    - Missing `lora` key (returns defaults; caller should check `is_lora_enabled`)
    - `freeze_extra_modules` as comma-separated string OR list
    - `target_modules` as string ("all-linear") OR list
    """
    if hasattr(cfg, "get"):
        block = cfg.get("lora", {})
    else:
        block = getattr(cfg, "lora", {})
    if block is None:
        block = {}

    # dict-like configs expose .get; attribute-style configs fall back to getattr
    if hasattr(block, "get"):
        read = block.get
    else:
        read = lambda k, d=None: getattr(block, k, d)

    raw_freeze = read("freeze_extra_modules", []) or []
    if isinstance(raw_freeze, str):
        freeze = [m.strip() for m in raw_freeze.split(",") if m.strip()]
    elif isinstance(raw_freeze, (list, tuple)):
        freeze = list(raw_freeze)
    else:
        # OmegaConf ListConfig
        try:
            freeze = list(raw_freeze)
        except TypeError:
            freeze = []

    targets = read("target_modules", "all-linear")
    if not isinstance(targets, str):
        # OmegaConf list -> python list
        try:
            targets = list(targets)
        except TypeError:
            pass

    return cls(
        rank=int(read("rank", 32)),
        alpha=int(read("alpha", 16)),
        dropout=float(read("dropout", 0.05)),
        target_modules=targets,
        init_lora_weights=str(read("init_lora_weights", "gaussian")),
        vlm_module=read("vlm_module", None),
        freeze_extra_modules=freeze,
    )
peft_config
peft_config()

Build a peft.LoraConfig from this spec.

Source code in AlphaBrain/training/trainer_utils/peft/config.py
def peft_config(self):
    """Build a `peft.LoraConfig` from this spec."""
    from peft import LoraConfig

    # spell out the field -> LoraConfig keyword mapping explicitly
    kwargs = {
        "r": self.rank,
        "lora_alpha": self.alpha,
        "lora_dropout": self.dropout,
        "target_modules": self.target_modules,
        "init_lora_weights": self.init_lora_weights,
    }
    return LoraConfig(**kwargs)
is_lora_enabled
is_lora_enabled(cfg: Any) -> bool

Return True iff cfg.lora.enabled is set.

Source code in AlphaBrain/training/trainer_utils/peft/config.py
def is_lora_enabled(cfg: Any) -> bool:
    """Return True iff `cfg.lora.enabled` is set."""
    if cfg is None:
        return False
    block = cfg.get("lora", {}) if hasattr(cfg, "get") else getattr(cfg, "lora", {})
    # OmegaConf DictConfig supports .get; plain objects use getattr
    if isinstance(block, DictConfig):
        return bool(block.get("enabled", False))
    if not block:
        return False
    return bool(getattr(block, "enabled", False))
injector

LoRA injection: freeze backbone + wrap with PEFT + freeze extras.

Extracted verbatim from the (more complete) implementation that previously lived in AlphaBrain/training/continual_learning/train.py. The simpler implementation in AlphaBrain/training/train_alphabrain.py is replaced by this version (a strict superset — the auto-detect / freeze_extras paths are no-op when the relevant yaml fields are absent, so QwenGR00T behavior is identical).

apply_lora
apply_lora(model: Module, cfg: Any, *, print_summary: bool = True) -> nn.Module

Apply LoRA in-place per spec.

Steps
  1. Resolve VLM interface (from lora.vlm_module or auto-detect via _VLM_REGISTRY).
  2. Freeze ALL params of the VLM interface wrapper.
  3. Replace vlm_interface.model = get_peft_model(...) so PEFT injects LoRA layers (their params are trainable, base remains frozen).
  4. Freeze each module listed in lora.freeze_extra_modules.
  5. Modules not touched by the above stay with their original requires_grad (typically full-FT — e.g. action_model, dino).

Returns the same model instance (mutated in place).

Source code in AlphaBrain/training/trainer_utils/peft/injector.py
def apply_lora(
    model: nn.Module,
    cfg: Any,
    *,
    print_summary: bool = True,
) -> nn.Module:
    """Inject LoRA adapters into *model* in place, driven by `cfg.lora`.

    The VLM backbone is fully frozen and wrapped with PEFT so that only the
    injected LoRA layers train; any modules named in
    ``lora.freeze_extra_modules`` are frozen as well, while everything else
    keeps its original ``requires_grad`` state (typically full fine-tuning,
    e.g. action_model / dino).

    Args:
        model: Framework model to mutate; also the return value.
        cfg: Config carrying a ``lora`` section (no-op when LoRA is disabled).
        print_summary: When True, log trainable-parameter statistics.

    Returns:
        nn.Module: The same ``model`` instance, mutated in place.

    Raises:
        AttributeError: If ``lora.vlm_module`` names a child that does not
            exist on ``model``.
    """
    if not is_lora_enabled(cfg):
        return model

    from peft import get_peft_model
    from AlphaBrain.model.framework.base_framework import _detect_vlm_interface

    lora_spec = LoRASpec.from_omega(cfg)
    peft_cfg = lora_spec.peft_config()

    # Step 1: locate the VLM interface wrapper — explicit config key first,
    # registry auto-detection otherwise.
    if lora_spec.vlm_module:
        if not hasattr(model, lora_spec.vlm_module):
            raise AttributeError(
                f"lora.vlm_module='{lora_spec.vlm_module}' not found on model "
                f"(available: {[child for child, _ in model.named_children()]})"
            )
        vlm = getattr(model, lora_spec.vlm_module)
    else:
        vlm = _detect_vlm_interface(model)
    assert vlm is not None, (
        "No VLM interface found for LoRA injection. "
        "Set lora.vlm_module explicitly in config."
    )

    # Steps 2 + 3: freeze every backbone parameter, then let PEFT re-wrap the
    # inner model — the adapter layers it adds are the only trainable part.
    vlm.requires_grad_(False)
    vlm.model = get_peft_model(vlm.model, peft_cfg)

    # Step 4: freeze any additional modules named in the config.
    for extra_name in lora_spec.freeze_extra_modules:
        if not hasattr(model, extra_name):
            logger.warning(f"freeze_extra_modules: '{extra_name}' not found, skipping")
            continue
        submodule = getattr(model, extra_name)
        tensor_count = sum(1 for _ in submodule.parameters())
        submodule.requires_grad_(False)
        logger.info(f"Froze extra module '{extra_name}' ({tensor_count} param tensors)")

    # Step 5: optional trainable-vs-total parameter summary.
    if print_summary:
        logger.info("LoRA enabled on VLM backbone")
        try:
            # PEFT helper; best-effort only (absent on some wrapper types).
            vlm.model.print_trainable_parameters()
        except Exception:
            pass
        num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        num_total = sum(p.numel() for p in model.parameters())
        logger.info(
            f"Total trainable: {num_trainable / 1e6:.1f}M / {num_total / 1e6:.1f}M "
            f"({100 * num_trainable / max(num_total, 1):.2f}%)"
        )

    return model
merge_lora_checkpoint

merge_lora_checkpoint.py — Merge LoRA adapter + non-VLM weights into a full checkpoint usable by the standard eval pipeline (server_policy.py + BaseFramework.from_pretrained).

Thin CLI wrapper around the sibling load_and_merge() helper. Located inside the peft module so it can be invoked via python -m without path hacks.

Usage (from repo root, starVLA env active): python -m AlphaBrain.training.trainer_utils.peft.merge_lora_checkpoint \ --base_config configs/continual_learning/qwengr00t_continual_libero.yaml \ --lora_adapter_dir results/Checkpoints/.../task_4_id4_steps_50000_lora_adapter \ --action_model_pt results/Checkpoints/.../task_4_id4_steps_50000_action_model.pt \ --output_path results/Checkpoints/.../task_4_id4_steps_50000_pytorch_model.pt