Model › Framework¶
Source path: AlphaBrain/model/framework/
Each framework is an independent VLA model implementation. Frameworks are registered via FRAMEWORK_REGISTRY in AlphaBrain.model.tools and constructed by the build_framework(cfg) factory based on cfg.framework.name.
Factory and registry¶
framework ¶
Framework factory utilities. Automatically builds registered framework implementations based on configuration.
Each framework module (e.g., M1.py, QwenFast.py) should register itself: from AlphaBrain.model.framework.framework_registry import FRAMEWORK_REGISTRY
@FRAMEWORK_REGISTRY.register("InternVLA-M1")
def build_model_framework(config):
return InternVLA_M1(config=config)
build_framework ¶
Build a framework model from config. Args: cfg: Config object (OmegaConf / namespace) containing: cfg.framework.name: Identifier string (e.g. "InternVLA-M1") Returns: nn.Module: Instantiated framework model.
Source code in AlphaBrain/model/framework/__init__.py
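The registry/factory pattern described above can be sketched with a minimal stand-in (this toy `Registry` is illustrative only; the real `FRAMEWORK_REGISTRY` lives in `AlphaBrain.model.tools` and builds actual `nn.Module` frameworks):

```python
class Registry:
    """Minimal name -> builder registry mirroring the documented usage."""

    def __init__(self):
        self._builders = {}

    def register(self, name):
        def decorator(builder):
            self._builders[name] = builder
            return builder
        return decorator

    def build(self, name, cfg):
        if name not in self._builders:
            raise KeyError(f"unknown framework: {name!r}")
        return self._builders[name](cfg)


FRAMEWORK_REGISTRY = Registry()


@FRAMEWORK_REGISTRY.register("InternVLA-M1")
def build_model_framework(config):
    # Stand-in for InternVLA_M1(config=config)
    return {"framework": "InternVLA-M1", "config": config}


def build_framework(cfg):
    # The real code reads cfg.framework.name from an OmegaConf object;
    # plain dicts are used here to keep the sketch self-contained.
    return FRAMEWORK_REGISTRY.build(cfg["framework"]["name"], cfg)


model = build_framework({"framework": {"name": "InternVLA-M1"}})
```

Registering at import time means simply importing a framework module (e.g. `M1.py`) makes it available to `build_framework(cfg)`.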
Base class and config utilities¶
base_framework ¶
Base framework abstraction providing: - Pretrained loading (config + normalization stats + weights) - Action space utilities (dimension, stats, (un)normalization) - Trainable module discovery helper Note: No device placement or optimizer concerns handled here (delegated to trainer).
BaseFramework ¶
Bases: PreTrainedModel
Lightweight base class for higher-level VLA model assemblies. Subclasses are expected to: - Accept a structured config - Register components in init - Use provided helpers for action normalization handling
Initialize base nn.Module. Subclasses add components.
Source code in AlphaBrain/model/framework/base_framework.py
trainable_module_keys property ¶
Enumerate trainable submodule names up to a depth.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
max_depth | Descent depth when traversing module tree. | required |
Returns:
| Type | Description |
|---|---|
List[str] | List[str]: Module path names considered trainable. |
from_pretrained classmethod ¶
Restore a model instance from a saved checkpoint.
Workflow
- Resolve checkpoint path
- Load config + dataset normalization statistics
- Build model with loaded config
- Load state_dict strictly (reports missing/unexpected keys)
- Attach normalization stats for later un-normalization
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
pretrained_checkpoint | str | Path to .pt/.safetensors file or self-contained checkpoint directory. | required |
**kwargs | Extra constructor overrides passed to subclass. | {} |
Returns:
| Name | Type | Description |
|---|---|---|
BaseFramework | None | Instantiated model (left on CPU; caller decides device). |
Raises:
| Type | Description |
|---|---|
RuntimeError | If state_dict key mismatch occurs under strict=True. |
FileNotFoundError | If underlying files are missing (surfaced earlier). |
Source code in AlphaBrain/model/framework/base_framework.py
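The strict state_dict step of the workflow can be illustrated with a small helper (`check_state_dict` is a hypothetical name; the real loading goes through PyTorch's `load_state_dict(strict=True)`, which reports keys the same way):

```python
def check_state_dict(model_keys, ckpt_keys, strict=True):
    """Report missing/unexpected keys; raise under strict=True on any mismatch."""
    missing = sorted(set(model_keys) - set(ckpt_keys))
    unexpected = sorted(set(ckpt_keys) - set(model_keys))
    if strict and (missing or unexpected):
        raise RuntimeError(
            f"state_dict mismatch: missing={missing}, unexpected={unexpected}"
        )
    return missing, unexpected
```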
convert_checkpoint_to_dir staticmethod ¶
Convert an old-format file checkpoint to the new self-contained directory format.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
old_ckpt_path | str | Path to old .safetensors/.pt checkpoint file. | required |
output_dir | str | Output directory path. If None, creates a directory alongside the file. | None |
base_vlm_path | str | Path to Qwen base model (for saving config + processor). If None, reads from the checkpoint's config.yaml. | None |
Source code in AlphaBrain/model/framework/base_framework.py
get_action_stats classmethod ¶
Retrieve raw action normalization statistics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
unnorm_key | Optional dataset stats key. | None |
Returns:
| Name | Type | Description |
|---|---|---|
dict | Stats structure (e.g. q01, q99, mask). |
Source code in AlphaBrain/model/framework/base_framework.py
unnormalize_actions staticmethod ¶
unnormalize_actions(normalized_actions: ndarray, action_norm_stats: Dict[str, ndarray]) -> np.ndarray
Map normalized actions (≈[-1, 1]) back to original value range.
Auto-detects normalization mode via the optional 'norm_mode' key in action_norm_stats (defaults to 'q99' for backward compatibility): - 'q99' → uses q01 / q99 bounds - 'min_max' → uses min / max bounds
Steps
- Clamp values to [-1, 1]
- Threshold channel index 6 to {0,1} (binary semantic)
- Apply linear scaling for masked dimensions
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
normalized_actions | ndarray | Array shape [T, D] (or chunk length × action_dim). | required |
action_norm_stats | Dict[str, ndarray] | Dict containing stat arrays and optional 'norm_mode'. | required |
Returns:
| Type | Description |
|---|---|
ndarray | np.ndarray: Unnormalized actions (same shape as input). |
Source code in AlphaBrain/model/framework/base_framework.py
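The three steps above can be sketched in NumPy. This is a hedged reconstruction from the documented behavior, not the actual source: the stat field names (`q01`/`q99`/`min`/`max`/`mask`) come from this page, and the `> 0` threshold for the binary channel is an assumption:

```python
import numpy as np


def unnormalize_actions(normalized_actions, action_norm_stats):
    """Sketch of the documented un-normalization; field names assumed."""
    mode = action_norm_stats.get("norm_mode", "q99")  # default keeps backward compat
    if mode == "q99":
        low = np.asarray(action_norm_stats["q01"])
        high = np.asarray(action_norm_stats["q99"])
    else:  # "min_max"
        low = np.asarray(action_norm_stats["min"])
        high = np.asarray(action_norm_stats["max"])
    mask = np.asarray(action_norm_stats.get("mask", np.ones_like(low, dtype=bool)))

    acts = np.clip(normalized_actions, -1.0, 1.0)        # step 1: clamp to [-1, 1]
    acts[..., 6] = np.where(acts[..., 6] > 0, 1.0, 0.0)  # step 2: binarize channel 6
    # step 3: linear map [-1, 1] -> [low, high] for masked dims; pass through otherwise
    return np.where(mask, 0.5 * (acts + 1.0) * (high - low) + low, acts)
```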
config_utils ¶
Shared configuration / utility helpers for framework components: - NamespaceWithGet: lightweight namespace behaving like a dict - OmegaConf conversion helpers - Config merging decorator for model init - Checkpoint config/statistics loading
NamespaceWithGet ¶
Bases: SimpleNamespace
get ¶
Return attribute value if present, else default (dict-like API).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
key | Attribute name. | required | |
default | Fallback if attribute missing. | None |
Returns:
| Name | Type | Description |
|---|---|---|
Any | Stored value or default. |
Source code in AlphaBrain/model/framework/config_utils.py
items ¶
Iterate (key, value) pairs like dict.items().
Returns:
| Type | Description |
|---|---|
| Generator[Tuple[str, Any], None, None] |
to_dict ¶
Recursively convert nested NamespaceWithGet objects into plain dicts.
Returns:
| Name | Type | Description |
|---|---|---|
dict | Fully materialized dictionary structure. |
Source code in AlphaBrain/model/framework/config_utils.py
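The three helpers above can be sketched together; behavior is inferred from this page, so treat it as an illustration rather than the actual `config_utils.py` implementation:

```python
from types import SimpleNamespace


class NamespaceWithGet(SimpleNamespace):
    """Lightweight namespace with a dict-like API (sketch)."""

    def get(self, key, default=None):
        # Return attribute value if present, else the fallback
        return getattr(self, key, default)

    def items(self):
        # Iterate (key, value) pairs like dict.items()
        yield from vars(self).items()

    def to_dict(self):
        # Recursively materialize nested namespaces into plain dicts
        return {
            k: v.to_dict() if isinstance(v, NamespaceWithGet) else v
            for k, v in self.items()
        }


cfg = NamespaceWithGet(name="QwenOFT", trainer=NamespaceWithGet(lr=1e-4))
```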
dict_to_namespace ¶
Create an OmegaConf config from a plain dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
d | Input dictionary. | required |
Returns:
| Name | Type | Description |
|---|---|---|
OmegaConf | DictConfig instance. |
merge_param_config ¶
Decorator for init to unify config handling.
Behavior
- Extract 'config' kwarg / arg (path | dict | OmegaConf | namespace)
- Convert to OmegaConf
- Merge with explicitly passed init parameters (explicit overrides file)
- Attach merged config to self.config
- Call original init with merged config
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
init | Original init function. | required |
Returns:
| Type | Description |
|---|---|
| Wrapped initializer. |
Source code in AlphaBrain/model/framework/config_utils.py
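The documented merge behavior (explicit parameters override the config, merged result attached to `self.config`) can be sketched with plain dicts; the real decorator converts everything to OmegaConf and also accepts paths and namespaces:

```python
import functools


def merge_param_config(init):
    """Sketch of the documented decorator using plain dicts."""
    @functools.wraps(init)
    def wrapper(self, *args, config=None, **kwargs):
        merged = dict(config or {})
        merged.update(kwargs)        # explicit init parameters override the config
        self.config = merged         # attach merged config to self.config
        return init(self, config=merged)
    return wrapper


class Demo:
    @merge_param_config
    def __init__(self, config=None):
        self.lr = config["lr"]


demo = Demo(config={"lr": 1e-3, "depth": 18}, lr=5e-4)
```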
read_model_config ¶
Load global model configuration and dataset normalization statistics associated with a saved checkpoint (.pt).
Expected directory layout
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
pretrained_checkpoint | Path to a .pt checkpoint file. | required |
Returns:
| Name | Type | Description |
|---|---|---|
tuple | global_cfg (dict): Loaded config.json contents. norm_stats (dict): Dataset statistics for (de)normalization. |
Raises:
| Type | Description |
|---|---|
FileNotFoundError | If checkpoint or required JSON files are missing. |
AssertionError | If file suffix or structure invalid. |
Source code in AlphaBrain/model/framework/config_utils.py
read_mode_config ¶
Same as read_model_config (legacy duplicate kept for backward compatibility).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
pretrained_checkpoint | Path to a .pt checkpoint file. | required |
Returns:
| Name | Type | Description |
|---|---|---|
tuple | vla_cfg (dict) norm_stats (dict) |
Source code in AlphaBrain/model/framework/config_utils.py
ToyVLA¶
ToyModel ¶
ToyVLA — Minimal VLA debugging model¶
Design goals: - No VLM dependency: no Qwen / LLM required, loads in seconds - Interface fully consistent with QwenOFT (forward / predict_action accept the same examples List[dict]) - Can overfit a small sample set within a few hundred steps → verifies that the training pipeline is correct
Verification method
- Feed in N fixed samples and train for a few hundred steps
- If action_loss approaches 0 and eval MSE approaches 0 → the pipeline is correct
- Otherwise there is a bug in the data → forward → loss → backward chain
Interface (same as QwenOFT): examples: List[dict] - "image" : List[PIL.Image] (multi-view, any size) - "lang" : str - "action" : np.ndarray shape (T, action_dim)
forward(examples) → {"action_loss": scalar_tensor} predict_action(examples) → {"normalized_actions": np.ndarray (B, chunk_len, action_dim)}
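The verification recipe above can be sketched without any VLA components: overfit a few fixed samples with a tiny model and check that the loss collapses. This pure-NumPy stand-in mirrors what ToyVLA does with its encoders and MLP:

```python
import numpy as np

# A few fixed "samples": if gradient descent cannot drive the loss toward 0
# on these, the data -> forward -> loss -> backward chain is broken.
rng = np.random.default_rng(0)
X = rng.normal(size=(4, 8))           # 4 fixed observations
Y = rng.normal(size=(4, 7))           # 4 fixed action targets
W = np.zeros((8, 7))

for _ in range(1000):                 # "a few hundred steps" in ToyVLA terms
    pred = X @ W
    grad = X.T @ (pred - Y) / len(X)  # MSE gradient
    W -= 0.1 * grad

mse = float(np.mean((X @ W - Y) ** 2))
```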
TinyImageEncoder ¶
Bases: Module
Compresses a PIL Image of any size into an (img_feat_dim,) vector; pure convolution, ~10K parameters.
Source code in AlphaBrain/model/framework/ToyModel.py
TinyTextEncoder ¶
Bases: Module
Source code in AlphaBrain/model/framework/ToyModel.py
forward ¶
texts: List[str] → (B, text_feat_dim)
Source code in AlphaBrain/model/framework/ToyModel.py
ToyVLA ¶
Bases: PreTrainedModel
Minimal VLA debugging model. - Replaces the Qwen VLM with TinyImageEncoder + TinyTextEncoder - Uses a small MLP for action regression - Under 200K parameters in total; overfits a small batch within seconds on a single GPU
Source code in AlphaBrain/model/framework/ToyModel.py
forward ¶
Returns:
| Type | Description |
|---|---|
dict | {"action_loss": scalar tensor} |
Source code in AlphaBrain/model/framework/ToyModel.py
predict_action ¶
Returns:
| Type | Description |
|---|---|
dict | {"normalized_actions": np.ndarray (B, chunk_len, action_dim)} |
Source code in AlphaBrain/model/framework/ToyModel.py
ACT¶
ACT ¶
ACT — Action Chunking Transformers (standalone implementation)¶
Reference
Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware Zhao et al., RSS 2023
Architecture
- ResNet18 visual encoder (per camera view)
- CVAE encoder: (robot_state, action_chunk) → z (training only; z=0 at inference)
- Transformer encoder: [z_token, img_tokens, state_token] → memory
- Transformer decoder: query_embed → action_chunk
Interface (same as QwenOFT / ToyVLA): examples: List[dict] - "image" : List[PIL.Image] (multi-view, any size) - "lang" : str (ignored during action prediction, kept for API compat) - "action" : np.ndarray shape (T, action_dim) - "state" : np.ndarray shape (T_state, state_dim) [optional]
forward(examples) → {"action_loss": tensor} predict_action(examples) → {"normalized_actions": np.ndarray (B, chunk_len, action_dim)}
CVAEEncoder ¶
CVAEEncoder(state_dim: int, action_dim: int, hidden_dim: int, latent_dim: int, num_heads: int = 4, num_layers: int = 2)
Bases: Module
Encodes (robot_state, action_chunk) → (mu, log_var). Inputs are projected to hidden_dim then fused through a small Transformer encoder.
Source code in AlphaBrain/model/framework/ACT.py
forward ¶
state: (B, state_dim) action_chunk: (B, chunk_len, action_dim) Returns: mu, log_var each (B, latent_dim)
Source code in AlphaBrain/model/framework/ACT.py
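The latent from the (mu, log_var) pair is used as the CVAE describes above: sampled via the reparameterization trick during training, zeroed at inference. A NumPy sketch of that sampling rule (not the actual `ACT.py` code):

```python
import numpy as np


def sample_z(mu, log_var, training, rng):
    """z = mu + sigma * eps during training; z = 0 at inference (as documented)."""
    if not training:
        return np.zeros_like(mu)
    return mu + np.exp(0.5 * log_var) * rng.normal(size=mu.shape)
```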
ACTModel ¶
Bases: PreTrainedModel
Standalone ACT (Action Chunking Transformers) model.
Key design choices vs. the paper: - Use the stock torchvision ResNet18 (without the paper's backbone unfreezing) - Replace FiLM conditioning with simple token concatenation - Use PyTorch's native Transformer encoder / decoder
Source code in AlphaBrain/model/framework/ACT.py
forward ¶
Returns:
| Type | Description |
|---|---|
dict | dict with keys: action_loss (scalar) kl_loss (scalar) |
Source code in AlphaBrain/model/framework/ACT.py
predict_action ¶
predict_action(examples: List[dict] = None, batch_images: List[List] = None, instructions: List[str] = None, states: ndarray = None, **kwargs) -> dict
Accepts two input formats:
-
examples format (train / debug): examples = [{"image": [PIL,...], "lang": str, "state": np.ndarray}, ...]
-
Flat format (from websocket server / M1Inference): batch_images = [[img0, img1], ...] (B × n_views, np.ndarray or PIL) instructions = ["task description", ...] states = np.ndarray (B, T, state_dim) or (B, state_dim)
Returns:
| Name | Type | Description |
|---|---|---|
dict | dict | normalized_actions: np.ndarray (B, chunk_len, action_dim) |
Source code in AlphaBrain/model/framework/ACT.py
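Converting between the two accepted input formats is mechanical; a small adapter (`to_examples` is a hypothetical helper, not in the source) shows the mapping:

```python
def to_examples(batch_images=None, instructions=None, states=None, examples=None):
    """Normalize the two documented input formats into the examples format (sketch)."""
    if examples is not None:          # examples format passes through untouched
        return examples
    out = []
    for i, imgs in enumerate(batch_images):
        ex = {"image": list(imgs), "lang": instructions[i]}
        if states is not None:        # state is optional
            ex["state"] = states[i]
        out.append(ex)
    return out


exs = to_examples(batch_images=[["view0", "view1"]], instructions=["pick up the cube"])
```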
CosmosPolicy¶
CosmosPolicy ¶
CosmosPolicy Framework
A video diffusion model (Cosmos Predict2 2B DiT) fine-tuned for robot policy prediction. Unlike VLM-based frameworks (QwenOFT, QwenGR00T), this uses latent-space diffusion: - WAN 2.1 VAE encodes images to latent space (frozen) - MiniTrainDIT backbone denoises latent sequences (trainable) - Actions/proprio/value are injected into latent frames - T5 text embeddings provide language conditioning (precomputed)
Latent frame layout (LIBERO, state_t=9): [blank, curr_proprio, curr_wrist, curr_primary, action, future_proprio, future_wrist, future_primary, value]
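The layout can be captured as constants. Indices are assumed 0-based and the names below are illustrative, not the actual `CosmosPolicy.py` symbols; the condition/prediction split follows the predict_action description (condition frames 0-3, prediction frames 4-8):

```python
# Latent frame layout for LIBERO (state_t = 9), per the documentation above.
LATENT_FRAMES = [
    "blank", "curr_proprio", "curr_wrist", "curr_primary", "action",
    "future_proprio", "future_wrist", "future_primary", "value",
]
COND_FRAMES = range(0, 4)                       # saved during denoising
PRED_FRAMES = range(4, 9)                       # replaced with noise, then denoised
ACTION_IDX = LATENT_FRAMES.index("action")      # frame the action chunk is read from
```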
CosmosPolicy ¶
Bases: BaseFramework
Cosmos-Policy: latent-space video diffusion for robot action prediction.
Training: VAE encode → inject action/proprio → diffusion loss on latent sequence Inference: multi-step denoising → extract action from latent frame
Source code in AlphaBrain/model/framework/CosmosPolicy.py
forward ¶
Training forward pass: diffusion denoising loss on full 9-frame latent sequence.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
examples | batched dict from DataLoader (each value is a stacked tensor), or list of dicts (legacy). | required |
Returns:
| Type | Description |
|---|---|
| {"action_loss": total_loss} |
Source code in AlphaBrain/model/framework/CosmosPolicy.py
predict_action ¶
Inference: multi-step denoising to predict action chunk.
Matches original cosmos-policy get_action() flow: 1. Build full 33-frame video (with placeholders for prediction frames) 2. VAE encode → 9 latent frames 3. Inject normalized proprio into frame 1 4. Save condition frames (0-3), replace prediction frames (4-8) with noise 5. Multi-step denoising 6. Extract action from denoised latent at action_idx
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
examples | list of dicts with image, wrist_image, lang, proprio | None | |
batch_images | alternative — list of PIL images | None | |
instructions | alternative — list of strings | None |
Returns:
| Type | Description |
|---|---|
| {"normalized_actions": np.ndarray of shape (B, chunk_size, action_dim)} |
Source code in AlphaBrain/model/framework/CosmosPolicy.py
set_dataset_stats ¶
Store dataset statistics for proprio normalization during inference.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dataset_stats | dict | dict with keys 'proprio_min', 'proprio_max' (np.ndarray). | required |
Source code in AlphaBrain/model/framework/CosmosPolicy.py
NeuroVLA¶
NeuroVLA ¶
NeuroVLA ¶
NeuroVLA(config: Optional[dict] = None, norm_stats: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]] = None, **kwargs)
Bases: BaseFramework
NeuroVLA: Vision-Language-Action model for robotic manipulation.
This model combines a vision-language model (Qwen-VL) with action prediction to generate robot actions from visual observations and language instructions.
Source code in AlphaBrain/model/framework/NeuroVLA.py
forward ¶
Run a forward pass through the VLM, returning loss for training.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
examples | List[dict] | List of training examples, each containing: - "image": Input images - "lang": Language instructions - "action": Ground truth actions [B, T, 7] - "state": Robot states [B, T, 8] - "solution" (optional): Chain-of-thought solutions | None |
Returns:
| Type | Description |
|---|---|
Tuple | Dictionary containing action_loss |
Source code in AlphaBrain/model/framework/NeuroVLA.py
predict_action ¶
predict_action(batch_images: Union[Image, List[Image]], instructions: List[str], states: Optional[List[Sequence[float]]] = None, solutions: Union[Dict, List[Dict]] = None, unnorm_key: Optional[str] = None, cfg_scale: float = 1.5, use_ddim: bool = False, num_ddim_steps: int = 5, **kwargs: str) -> np.ndarray
Predict action from images and instructions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch_images | Union[Image, List[Image]] | Input images (PIL Image or list of PIL Images) | required |
instructions | List[str] | Task instructions (list of strings) | required |
states | Optional[List[Sequence[float]]] | Robot states history [B, T, 8], where last dim is [x,y,z,roll,pitch,yaw,gripper,pad] | None |
solutions | Union[Dict, List[Dict]] | Optional solution dict for chain-of-thought | None |
unnorm_key | Optional[str] | Key for unnormalization (if using norm_stats) | None |
cfg_scale | float | Classifier-free guidance scale (>1.0 enables CFG) | 1.5 |
use_ddim | bool | Whether to use DDIM sampling | False |
num_ddim_steps | int | Number of DDIM steps | 5 |
Returns:
| Type | Description |
|---|---|
ndarray | Dictionary containing "normalized_actions" [B, T, 7] |
Source code in AlphaBrain/model/framework/NeuroVLA.py
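The cfg_scale parameter follows the standard classifier-free guidance formulation: the conditional and unconditional predictions are combined, and scale > 1.0 pushes the output toward the conditional branch. A sketch of that combine step (standard CFG, not the exact NeuroVLA code):

```python
import numpy as np


def cfg_combine(cond, uncond, cfg_scale=1.5):
    """Classifier-free guidance: scale 1.0 reduces to the conditional prediction."""
    return uncond + cfg_scale * (cond - uncond)
```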
build_model_framework ¶
PaliGemma family¶
PaliGemmaOFT ¶
PaliGemma-OFT Framework
Uses PaliGemma (SigLIP + Gemma 2B) as VLM backbone with action special token for continuous action prediction via L1 regression. Mirrors LlamaOFT / QwenOFT architecture but with PaliGemma backbone.
PaliGemma_OFT ¶
Bases: BaseFramework
PaliGemma + action token OFT framework. Predicts continuous actions via L1 regression on action token hidden states.
Source code in AlphaBrain/model/framework/PaliGemmaOFT.py
forward ¶
Training forward: L1 regression on action tokens.
Source code in AlphaBrain/model/framework/PaliGemmaOFT.py
predict_action ¶
predict_action(batch_images: List = None, instructions: List[str] = None, examples: List[dict] = None, **kwargs) -> np.ndarray
Inference: predict normalized actions.
Source code in AlphaBrain/model/framework/PaliGemmaOFT.py
PaliGemmaPi0 ¶
PaliGemmaPi0 Framework
Integrates the π₀/π₀.₅ flow matching architecture into VLA-Engine. Key innovation: the VLM backbone is swappable — you can use PaliGemma (original), Qwen2.5-VL, Llama 3.2 Vision, or any future VLM backend.
Architecture
VLM (any) → prefix embedding → [KV cache] → Action Expert (Gemma) + Flow Matching → actions
Components
- VLM interface: reuses AlphaBrain's existing get_vlm_model() factory
- Action Expert: independent Gemma transformer (from openpi)
- Flow Matching Head: multi-step denoising action generation
Training: flow matching loss (MSE between predicted and target velocity fields) Inference: iterative denoising from Gaussian noise (default 10 steps)
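The flow matching loss above can be sketched in NumPy. The linear interpolation path and its velocity target are an assumption (the standard rectified-flow formulation), not read from the source:

```python
import numpy as np

# Assumed path: x_t = (1 - t) * noise + t * actions, so the target velocity
# field is dx_t/dt = actions - noise. Training regresses the predicted
# velocity v_pred(x_t, t, prefix) onto this target with MSE.
rng = np.random.default_rng(0)
actions = rng.normal(size=(2, 50, 7))   # (B, action_horizon, action_dim)
noise = rng.normal(size=actions.shape)
t = rng.uniform(size=(2, 1, 1))         # per-sample interpolation time

x_t = (1 - t) * noise + t * actions
target_velocity = actions - noise
# loss = mean((v_pred - target_velocity) ** 2) once the action expert produces v_pred
```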
PaliGemma_OFT ¶
Bases: BaseFramework
Pi0/Pi0.5 framework with swappable VLM backbone.
Config structure
framework:
  name: PaliGemmaOFT
  pi05: true                # true for π₀.₅, false for π₀
  paligemma:                # or qwenvl/llamavl (uses get_vlm_model())
    base_vlm: google/paligemma-3b-pt-224
  action_expert:
    width: 1024
    depth: 18
    ...
  action_model:
    action_dim: 7
    action_horizon: 50
    num_inference_steps: 10
Source code in AlphaBrain/model/framework/PaliGemmaPi0.py
forward ¶
Training forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
examples | List[dict] | list of dicts with keys: image, lang, action, (state) | None |
Returns:
| Type | Description |
|---|---|
Tuple | (loss, metrics_dict) |
Source code in AlphaBrain/model/framework/PaliGemmaPi0.py
predict_action ¶
predict_action(batch_images: List = None, instructions: List[str] = None, examples: List[dict] = None, unnorm_key=None, **kwargs)
Inference: predict actions via multi-step denoising.
Returns:
| Type | Description |
|---|---|
| np.ndarray: [B, action_horizon, action_dim] unnormalized actions |
Source code in AlphaBrain/model/framework/PaliGemmaPi0.py
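The multi-step denoising can be sketched as Euler integration of the learned velocity field, starting from Gaussian noise. The integrator choice is an assumption (the real head comes from openpi); the ideal-field check below shows the mechanics:

```python
import numpy as np


def denoise(v_field, shape, num_steps=10, seed=0):
    """Euler-step denoising sketch over t in [0, 1)."""
    x = np.random.default_rng(seed).normal(size=shape)  # start from Gaussian noise
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        x = x + dt * v_field(x, t)  # follow the predicted velocity field
    return x


# With the ideal field v(x, t) = (target - x) / (1 - t), Euler steps land on target.
target = np.ones((1, 50, 7))        # (B, action_horizon, action_dim)
out = denoise(lambda x, t: (target - x) / (1.0 - t), target.shape)
```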
PaliGemmaPi05 ¶
PaliGemmaPi05 Framework
Integrates the π₀/π₀.₅ flow matching architecture into VLA-Engine. Key innovation: the VLM backbone is swappable — you can use PaliGemma (original), Qwen2.5-VL, Llama 3.2 Vision, or any future VLM backend.
Architecture
VLM (any) → prefix embedding → [KV cache] → Action Expert (Gemma) + Flow Matching → actions
Components
- VLM interface: reuses VLAE's existing get_vlm_model() factory
- Action Expert: independent Gemma transformer (from openpi)
- Flow Matching Head: multi-step denoising action generation
Training: flow matching loss (MSE between predicted and target velocity fields) Inference: iterative denoising from Gaussian noise (default 10 steps)
PaliGemma_Pi05 ¶
Bases: BaseFramework
Pi0/Pi0.5 framework with swappable VLM backbone.
Config structure
framework:
  name: PaliGemmaOFT
  pi05: true                # true for π₀.₅, false for π₀
  paligemma:                # or qwenvl/llamavl (uses get_vlm_model())
    base_vlm: google/paligemma-3b-pt-224
  action_expert:
    width: 1024
    depth: 18
    ...
  action_model:
    action_dim: 7
    action_horizon: 50
    num_inference_steps: 10
Source code in AlphaBrain/model/framework/PaliGemmaPi05.py
forward ¶
Training forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
examples | List[dict] | list of dicts with keys: image, lang, action, (state) | None |
Returns:
| Type | Description |
|---|---|
Tuple | (loss, metrics_dict) |
Source code in AlphaBrain/model/framework/PaliGemmaPi05.py
predict_action ¶
predict_action(batch_images: List = None, instructions: List[str] = None, examples: List[dict] = None, unnorm_key=None, **kwargs)
Inference: predict actions via multi-step denoising.
Returns:
| Type | Description |
|---|---|
| np.ndarray: [B, action_horizon, action_dim] unnormalized actions |
Source code in AlphaBrain/model/framework/PaliGemmaPi05.py
Llama OFT¶
LlamaOFT ¶
Llama-OFT Framework
Uses Llama 3.2 Vision as backbone with action special token for continuous action prediction. Mirrors QwenOFT but swaps Qwen for Llama 3.2 Vision.
Llama_OFT ¶
Bases: BaseFramework
Llama 3.2 Vision + action token OFT framework. Predicts continuous actions via L1 regression on action token hidden states.
Source code in AlphaBrain/model/framework/LlamaOFT.py
Qwen family¶
QwenOFT ¶
Qwen-OFT Framework
A lightweight implementation that uses an action special token to predict continuous actions in parallel, conditioned on multi-view images plus a language instruction (shares parameters with the VLM). Inspired by OpenVLA-OFT. Key Points: - Qwen2.5 vision-language backbone - Injects an action special token into the VLM - Continuous action prediction via L1 regression over the action special token hidden states
How to add special tokens to Qwen2.5:
Download our model checkpoint with the special tokens already added: https://huggingface.co/AlphaBrain/Qwen2.5-VL-3B-Instruct-Action, or follow AlphaBrain/model/modules/vlm/tools/add_qwen_special_tokens/README.md (adapting a small amount of code).
Qwenvl_OFT ¶
Bases: BaseFramework
Multimodal vision-language-action model.
Components
- Qwen2.5 VL interface for fused language/vision token embeddings
- Layer-wise QFormer for multi-layer feature aggregation
- DINO encoder for dense multi-view spatial tokens
- DiT diffusion head for future action sequence modeling
Focus: Predict future continuous actions conditioned on images + instruction.
Construct all submodules and cache key configuration values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config | Optional[dict] | Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections. | None |
**kwargs | Reserved for future overrides (unused). | {} |
Source code in AlphaBrain/model/framework/QwenOFT.py
forward ¶
Training forward: directly regresses future actions (no diffusion).
Flow
- Build QwenVL inputs (images + instruction tokens)
- Extract hidden states from configured layer range
- Predict action and compute L1 loss
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
examples | List[dict] | List[dict], each dict requires: - image: List[PIL.Image] (multi-view) - lang: str instruction - action: np.ndarray or list shaped [T, action_dim] | None |
**kwargs | Reserved. | {} |
Returns:
| Name | Type | Description |
|---|---|---|
dict | Tuple | action_loss (torch.Tensor): Scalar L1 action regression loss. |
Source code in AlphaBrain/model/framework/QwenOFT.py
predict_action ¶
predict_action(batch_images: List = None, instructions: List[str] = None, examples: List[dict] = None, **kwargs) -> np.ndarray
Inference: a single forward pass directly regresses future actions (no diffusion sampling).
Accepts two input formats
- Flat format (from M1Inference websocket client): batch_images + instructions
- Legacy format: examples (list of dicts with "image" and "lang" keys)
Steps
- Resize images to training resolution (if specified)
- Encode with QwenVL (hidden states retained)
- Return normalized action trajectory
Returns:
| Name | Type | Description |
|---|---|---|
dict | ndarray | normalized_actions (np.ndarray): Shape [B, T, action_dim], directly regressed normalized actions. |
Source code in AlphaBrain/model/framework/QwenOFT.py
get_action_queries ¶
Extract action_queries from frozen VLM without going through the action head.
Returns:
| Name | Type | Description |
|---|---|---|
action_queries | Tensor | (B, chunk_len, H) tensor on model device |
Source code in AlphaBrain/model/framework/QwenOFT.py
get_vla_action ¶
Get both action_queries and VLA base action predictions (frozen).
Returns:
| Name | Type | Description |
|---|---|---|
| action_queries | (B, chunk_len, H) tensor | |
| vla_actions | (B, chunk_len, action_dim) tensor (normalized) | |
Source code in AlphaBrain/model/framework/QwenOFT.py
QwenPI ¶
Qwen-PI Framework. A lightweight implementation that combines Qwen2.5-VL with a flow-matching head to directly predict continuous actions. The flow-matching head is adapted from GR00T N1.5, with a simple MoE inspired by PI_0.
Extended (2026-04): World model backbone support (V-JEPA, Cosmos, Wan). When a world model VLM is used, forward_all_layers() extracts per-backbone- block features so each PI action head layer cross-attends to a DIFFERENT backbone layer (true layerwise cross-attention, no replication).
Qwen_PI ¶
Bases: BaseFramework
Multimodal vision-language-action model.
Components
- Qwen2.5 VL interface for fused language/vision token embeddings
- Layer-wise cross DiT diffusion head
World model mode
When the VLM is a world model (Cosmos, Wan), per-DiT-block features are extracted via forward_all_layers() so each PI action head layer cross-attends to a DIFFERENT backbone layer. framework.qwenvl.num_vl_layers must match the backbone block count (28 for Cosmos, 30 for Wan).
Focus: Predict future continuous actions conditioned on images + instruction.
Source code in AlphaBrain/model/framework/QwenPI.py
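The world-model-mode constraint above (num_vl_layers must match the backbone block count, so each action-head layer gets a distinct backbone layer) can be validated up front. A sketch using only the block counts quoted in this section; V-JEPA's count is not stated here, and the helper name is hypothetical:

```python
# Backbone block counts quoted in the docs above.
BACKBONE_BLOCKS = {"cosmos": 28, "wan": 30}

def check_num_vl_layers(backbone: str, num_vl_layers: int) -> None:
    """Fail fast if framework.qwenvl.num_vl_layers cannot give a 1:1
    layerwise cross-attention mapping onto the backbone (hypothetical helper)."""
    expected = BACKBONE_BLOCKS[backbone]
    if num_vl_layers != expected:
        raise ValueError(
            f"framework.qwenvl.num_vl_layers must be {expected} for "
            f"{backbone!r}, got {num_vl_layers}"
        )

check_num_vl_layers("cosmos", 28)  # passes silently
```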
forward ¶
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
examples | List[dict] | List[dict], each dict requires: - image: List[PIL.Image] (multi-view) - lang: str instruction - action: np.ndarray or list shaped [T, action_dim] | None |
Returns: dict: action_loss (torch.Tensor): Scalar diffusion noise prediction loss.
Source code in AlphaBrain/model/framework/QwenPI.py
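The training objective behind the returned action_loss can be sketched as rectified-flow velocity regression. The linear interpolation path is an assumption; GR00T N1.5's head details may differ:

```python
import numpy as np

def flow_matching_targets(actions, noise, t):
    """Linear path x_t = (1-t)*noise + t*actions with velocity target
    v* = actions - noise (rectified-flow convention; an assumption)."""
    t = np.reshape(t, (-1, 1, 1))
    x_t = (1.0 - t) * noise + t * actions
    v_target = actions - noise
    return x_t, v_target

def flow_matching_loss(v_pred, v_target):
    """MSE between the head's predicted velocity and the target."""
    return float(((v_pred - v_target) ** 2).mean())

actions = np.ones((2, 4, 7))
noise = np.zeros((2, 4, 7))
x_t, v_star = flow_matching_targets(actions, noise, t=np.array([0.5, 0.5]))
perfect = flow_matching_loss(v_star, v_star)  # 0.0 for a perfect prediction
```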
predict_action ¶
predict_action(examples: List[dict] = None, batch_images: List = None, instructions: List[str] = None, states=None, **kwargs) -> np.ndarray
Inference: single forward pass to regress future actions via flow-matching sampling through the layerwise DiT.
Supports two input formats
- examples: List[dict] with keys "image", "lang", "state" (legacy)
- batch_images + instructions: direct arguments
Returns:
| Name | Type | Description |
|---|---|---|
dict | ndarray | normalized_actions (np.ndarray): Shape [B, T, action_dim]. |
Source code in AlphaBrain/model/framework/QwenPI.py
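At inference, flow-matching sampling reduces to integrating the learned velocity field from noise (t=0) to actions (t=1). A minimal Euler integrator sketch; the step count and fixed-step schedule are assumptions:

```python
import numpy as np

def euler_sample(velocity_fn, shape, num_steps=10, rng=None):
    """Integrate dx/dt = v(x, t) from t=0 (Gaussian noise) to t=1 (actions)."""
    rng = rng if rng is not None else np.random.default_rng(0)
    x = rng.standard_normal(shape)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        x = x + dt * velocity_fn(x, i * dt)
    return x

# With a constant unit velocity field, the sample moves by exactly +1.
out = euler_sample(lambda x, t: np.ones_like(x), (1, 4, 7), num_steps=5,
                   rng=np.random.default_rng(42))
```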
QwenGR00T ¶
Qwen-GR00T Framework. A lightweight implementation that combines Qwen-VL with a flow-matching head to directly predict continuous actions. The flow-matching head is adapted from GR00T N1.5.
Qwen_GR00T ¶
Bases: BaseFramework
Multimodal vision-language-action model.
Components
- Qwen2.5 VL interface for fused language/vision token embeddings
- Layer-wise QFormer for multi-layer feature aggregation
- DINO encoder for dense multi-view spatial tokens
- DiT diffusion head for future action sequence modeling
Focus: Predict future continuous actions conditioned on images + instruction.
Construct all submodules and cache key configuration values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | Optional[dict] | Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections. | None |
| **kwargs | | Reserved for future overrides (unused). | {} |
Source code in AlphaBrain/model/framework/QwenGR00T.py
forward ¶
Run a full training forward pass, with video loss when next_image is available.
When next_image is available, performs a SINGLE DiT forward that simultaneously yields action visual tokens and the next-frame video prediction loss. Both share the same backward graph so the DiT backbone receives gradients from both losses without a redundant forward pass.
During inference (no next_image): the standard encode path is used and no video loss is computed.
Source code in AlphaBrain/model/framework/QwenGR00T.py
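The conditional objective described above can be sketched as follows; the additive combination and the weight are assumptions about how the two losses from the shared DiT forward are merged:

```python
def combine_losses(action_loss, video_loss=None, video_weight=1.0):
    """Joint objective from the single shared DiT forward: the video term is
    added only when next_image was available (weighting is an assumption)."""
    if video_loss is None:  # inference / no next_image: action loss only
        return action_loss
    return action_loss + video_weight * video_loss

# Training with next_image available vs. inference without it:
train_total = combine_losses(1.0, 0.5, video_weight=2.0)  # 2.0
infer_total = combine_losses(1.0)                          # 1.0
```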
predict_action ¶
predict_action(examples: List[dict] = None, batch_images: List = None, instructions: List[str] = None, states=None, return_predicted_frame: bool = False, **kwargs) -> np.ndarray
Steps
- Resize images to training resolution (if specified)
- Encode with QwenVL (hidden states retained)
- Return normalized action trajectory
Supports two input formats
- examples: List[dict] with keys "image", "lang", "state" (legacy format)
- batch_images + instructions: direct arguments (consistent with NeuroVLA/QwenOFT)
Returns:
| Name | Type | Description |
|---|---|---|
dict | ndarray | normalized_actions (np.ndarray): Shape [B, T, action_dim], diffusion-sampled normalized actions. |
Source code in AlphaBrain/model/framework/QwenGR00T.py
predict_future_frame ¶
predict_future_frame(batch_images: List = None, instructions: List[str] = None, num_steps: int = 5, sigma_min: float = 4.0, sigma_max: float = 80.0) -> np.ndarray
Predict future frame using DiT denoising.
Returns:
| Name | Type | Description |
|---|---|---|
future_frames | ndarray | np.ndarray [B, H, W, 3] uint8 predicted future frames |
Source code in AlphaBrain/model/framework/QwenGR00T.py
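The sigma_min/sigma_max defaults in the signature above suggest an EDM/Karras-style noise schedule for the DiT denoiser. A sketch of such a schedule; the exact schedule and the rho value are assumptions, not taken from the source:

```python
import numpy as np

def karras_sigmas(num_steps=5, sigma_min=4.0, sigma_max=80.0, rho=7.0):
    """Karras-style noise levels descending from sigma_max to sigma_min,
    with a terminal 0.0 appended for the final denoised output."""
    ramp = np.linspace(0.0, 1.0, num_steps)
    inv = sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    return np.append(inv ** rho, 0.0)

sigmas = karras_sigmas()  # 80.0 down to 4.0, then 0.0; length num_steps + 1
```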
World Model VLA¶
WorldModelVLA ¶
WorldModelVLA Framework (Phase 2b-A)
Clean-room rename/clone of Qwen_GR00T for world-model-backbone VLA:
- Uses WorldModelVLMInterface (not a Qwen VLM) as visual encoder
- Reads config.framework.world_model. (not config.framework.qwenvl.)
- Attribute renamed: self.qwen_vl_interface -> self.world_model_encoder
- Framework-local wrapper: prepare_inputs (still calls interface.build_vlm_inputs under the hood; interface method rename deferred to Phase 3+)
QwenGR00T.py remains the canonical path for Qwen VLMs and is NOT modified.
WorldModelVLA ¶
Bases: BaseFramework
World-Model-backbone vision-language-action model.
Components
- WorldModelVLMInterface: Cosmos-Predict2 / V-JEPA2 / WAN2 visual encoder with lightweight text encoder and cross-attention fusion
- FlowmatchingActionHead (GR00T-N1.5) for future action sequence modeling
Focus: Predict future continuous actions conditioned on images + instruction, optionally with next-frame video loss for joint WM + action training.
Construct all submodules and cache key configuration values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | Optional[dict] | Hierarchical configuration (OmegaConf/dict) containing framework.world_model. (encoder) and framework.action_model. (FlowMatching head) sections. | None |
| **kwargs | | Reserved for future overrides (unused). | {} |
Source code in AlphaBrain/model/framework/WorldModelVLA.py
prepare_inputs ¶
Framework-local wrapper around the encoder's build_vlm_inputs.
Interface method rename is deferred to Phase 3+; for now the underlying call still hits self.world_model_encoder.build_vlm_inputs(...).
Source code in AlphaBrain/model/framework/WorldModelVLA.py
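The delegation described above can be sketched in a few lines; the classes below are illustrative stand-ins, not the real implementation:

```python
class _WorldModelVLASketch:
    """Sketch of the framework-local wrapper (hypothetical minimal class)."""

    def __init__(self, encoder):
        # Renamed attribute: self.qwen_vl_interface -> self.world_model_encoder
        self.world_model_encoder = encoder

    def prepare_inputs(self, *args, **kwargs):
        # Still delegates to the interface's old method name
        # (rename deferred to Phase 3+).
        return self.world_model_encoder.build_vlm_inputs(*args, **kwargs)

class _DummyEncoder:
    def build_vlm_inputs(self, images, lang):
        return {"images": images, "lang": lang}

inputs = _WorldModelVLASketch(_DummyEncoder()).prepare_inputs(["img"], "wipe the table")
# → {"images": ["img"], "lang": "wipe the table"}
```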
forward ¶
Run a full training forward pass, with video loss when next_image is available.
When next_image is available, performs a SINGLE DiT forward that simultaneously yields action visual tokens and the next-frame video prediction loss. Both share the same backward graph so the DiT backbone receives gradients from both losses without a redundant forward pass.
During inference (no next_image): the standard encode path is used and no video loss is computed.
Source code in AlphaBrain/model/framework/WorldModelVLA.py
predict_action ¶
predict_action(examples: List[dict] = None, batch_images: List = None, instructions: List[str] = None, states=None, return_predicted_frame: bool = False, **kwargs) -> np.ndarray
Steps
- Resize images to training resolution (if specified)
- Encode with world model (hidden states retained)
- Run FlowMatching action head
- Return normalized action trajectory
Supports two input formats
- examples: List[dict] with keys "image", "lang", "state" (legacy format)
- batch_images + instructions: direct arguments (consistent with NeuroVLA/QwenOFT)
Returns:
| Name | Type | Description |
|---|---|---|
| dict | ndarray | normalized_actions (np.ndarray): Shape [B, T, action_dim], flow-matching-sampled normalized actions. |
Source code in AlphaBrain/model/framework/WorldModelVLA.py
predict_future_frame ¶
predict_future_frame(batch_images: List = None, instructions: List[str] = None, num_steps: int = 5, sigma_min: float = 4.0, sigma_max: float = 80.0) -> np.ndarray
Predict future frame using DiT denoising.
Returns:
| Name | Type | Description |
|---|---|---|
future_frames | ndarray | np.ndarray [B, H, W, 3] uint8 predicted future frames |