Skip to content

Model › Framework

Source path: AlphaBrain/model/framework/

Each framework is an independent VLA model implementation. Frameworks are registered via FRAMEWORK_REGISTRY in AlphaBrain.model.tools and constructed by the build_framework(cfg) factory based on cfg.framework.name.


Factory and registry

framework

Framework factory utilities. Automatically builds registered framework implementations based on configuration.

Each framework module (e.g., M1.py, QwenFast.py) should register itself: from AlphaBrain.model.framework.framework_registry import FRAMEWORK_REGISTRY

@FRAMEWORK_REGISTRY.register("InternVLA-M1")
def build_model_framework(config):
    return InternVLA_M1(config=config)

build_framework

build_framework(cfg)

Build a framework model from config. Args: cfg: Config object (OmegaConf / namespace) containing: cfg.framework.name: Identifier string (e.g. "InternVLA-M1") Returns: nn.Module: Instantiated framework model.

Source code in AlphaBrain/model/framework/__init__.py
def build_framework(cfg):
    """
    Build a framework model from config.

    Args:
        cfg: Config object (OmegaConf / namespace) containing:
             cfg.framework.name: Identifier string (e.g. "InternVLA-M1")

    Returns:
        nn.Module: Instantiated framework model.

    Raises:
        NotImplementedError: If the framework name is neither a built-in
            implementation nor registered in FRAMEWORK_REGISTRY.
    """
    import importlib

    if not hasattr(cfg.framework, "name"):
        cfg.framework.name = cfg.framework.framework_py  # Backward compatibility for legacy config yaml

    name = cfg.framework.name

    # Built-in implementations: name -> (module path, class name, pass config as keyword?).
    # Modules are resolved lazily so only the selected framework's (often heavy)
    # dependencies actually get imported.
    builtin = {
        "ToyVLA": ("AlphaBrain.model.framework.ToyModel", "ToyVLA", True),
        "QwenOFT": ("AlphaBrain.model.framework.QwenOFT", "Qwenvl_OFT", False),
        "QwenFast": ("AlphaBrain.model.framework.QwenFast", "Qwenvl_Fast", False),
        "NeuroVLA": ("AlphaBrain.model.framework.NeuroVLA", "NeuroVLA", False),
        "QwenGR00T": ("AlphaBrain.model.framework.QwenGR00T", "Qwen_GR00T", False),
        "ACT": ("AlphaBrain.model.framework.ACT", "ACTModel", True),
        "CosmosPolicy": ("AlphaBrain.model.framework.CosmosPolicy", "CosmosPolicy", True),
        "PaliGemmaOFT": ("AlphaBrain.model.framework.PaliGemmaOFT", "PaliGemma_OFT", False),
        # PaliGemmaPi05 = π₀.₅ architecture (adaRMS timestep conditioning);
        # "PaliGemmaPi0" kept as backward-compatible alias for historical checkpoints.
        "PaliGemmaPi05": ("AlphaBrain.model.framework.PaliGemmaPi05", "PaliGemma_Pi05", False),
        "PaliGemmaPi0": ("AlphaBrain.model.framework.PaliGemmaPi05", "PaliGemma_Pi05", False),
        "LlamaOFT": ("AlphaBrain.model.framework.LlamaOFT", "Llama_OFT", False),
    }

    if name in builtin:
        module_path, class_name, as_kwarg = builtin[name]
        model_cls = getattr(importlib.import_module(module_path), class_name)
        # Some constructors expect `config=` as a keyword, others positional.
        return model_cls(config=cfg) if as_kwarg else model_cls(cfg)

    # Fall back to frameworks registered via FRAMEWORK_REGISTRY.
    if name not in FRAMEWORK_REGISTRY._registry:
        raise NotImplementedError(
            f"Framework {cfg.framework.name} is not implemented. "
            "Register it in FRAMEWORK_REGISTRY or specify a valid framework module."
        )

    MODEL_CLASS = FRAMEWORK_REGISTRY[name]
    return MODEL_CLASS(cfg)

tools

FRAMEWORK_REGISTRY module-attribute

FRAMEWORK_REGISTRY = Registry('frameworks')

Base class and config utilities

base_framework

Base framework abstraction providing: - Pretrained loading (config + normalization stats + weights) - Action space utilities (dimension, stats, (un)normalization) - Trainable module discovery helper Note: No device placement or optimizer concerns handled here (delegated to trainer).

BaseFramework

BaseFramework(hf_config=PretrainedConfig())

Bases: PreTrainedModel

Lightweight base class for higher-level VLA model assemblies. Subclasses are expected to: - Accept a structured config - Register components in init - Use provided helpers for action normalization handling

Initialize base nn.Module. Subclasses add components.

Source code in AlphaBrain/model/framework/base_framework.py
def __init__(
    self,
    hf_config=None,
) -> None:
    """
    Initialize base nn.Module. Subclasses add components.

    Args:
        hf_config: Optional HuggingFace ``PretrainedConfig``. When omitted, a
            fresh default config is created per instance. (The previous
            ``hf_config=PretrainedConfig()`` default was evaluated once at
            definition time, so every instance silently shared one mutable
            config object — the classic mutable-default-argument pitfall.)
    """
    if hf_config is None:
        hf_config = PretrainedConfig()
    super().__init__(hf_config)
trainable_module_keys property
trainable_module_keys: List[str]

Enumerate trainable submodule names up to a depth.

Parameters:

Name Type Description Default
max_depth

Descent depth when traversing module tree.

required

Returns:

Type Description
List[str]

List[str]: Module path names considered trainable.

from_pretrained classmethod
from_pretrained(pretrained_checkpoint: str, **kwargs) -> None

Restore a model instance from a saved checkpoint.

Workflow
  1. Resolve checkpoint path
  2. Load config + dataset normalization statistics
  3. Build model with loaded config
  4. Load state_dict strictly (reports missing/unexpected keys)
  5. Attach normalization stats for later un-normalization

Parameters:

Name Type Description Default
pretrained_checkpoint str

Path to .pt/.safetensors file or self-contained checkpoint directory.

required
**kwargs

Extra constructor overrides passed to subclass.

{}

Returns:

Name Type Description
BaseFramework None

Instantiated model (left on CPU; caller decides device).

Raises:

Type Description
RuntimeError

If state_dict key mismatch occurs under strict=True.

FileNotFoundError

If underlying files are missing (surfaced earlier).

Source code in AlphaBrain/model/framework/base_framework.py
@classmethod
def from_pretrained(
    cls,
    pretrained_checkpoint: str,
    **kwargs,
) -> None:
    """
    Restore a model instance from a saved checkpoint.

    Workflow:
        1. Resolve checkpoint path
        2. Load config + dataset normalization statistics
        3. Build model with loaded config
        4. Load state_dict strictly (reports missing/unexpected keys)
        5. Attach normalization stats for later un-normalization

    Args:
        pretrained_checkpoint: Path to .pt/.safetensors file or self-contained checkpoint directory.
        **kwargs: Extra constructor overrides passed to subclass.

    Returns:
        BaseFramework: Instantiated model (left on CPU; caller decides device).

    Raises:
        RuntimeError: If state_dict key mismatch occurs under strict=True
            (file-checkpoint path; the directory path falls back to
            non-strict loading instead).
        FileNotFoundError: If underlying files are missing (surfaced earlier).
    """
    pretrained_checkpoint = Path(pretrained_checkpoint)

    # lpt0309: self-contained directory-format checkpoint — config, stats and
    # weights are all inferred from the single path (no external base_vlm).
    if pretrained_checkpoint.is_dir():
        logger.info(f"[lpt0309] Loading from self-contained checkpoint directory: {pretrained_checkpoint}")
        model_config, norm_stats = read_mode_config(pretrained_checkpoint)  # lpt0309: config + norm_stats from the directory

        config = dict_to_namespace(model_config)
        config.trainer.pretrained_checkpoint = None

        # Single-read optimization: if the checkpoint contains vlm_pretrained/
        # (legacy name: qwen_pretrained/), load tokenizer/config straight from
        # it, build the model skeleton on the meta device, and let the later
        # load_state_dict populate every weight in one pass.
        vlm_pretrained_dir = pretrained_checkpoint / "vlm_pretrained"
        legacy_dir = pretrained_checkpoint / "qwen_pretrained"
        if not (vlm_pretrained_dir.is_dir() and any(vlm_pretrained_dir.iterdir())):
            # Backward compatibility with the old qwen_pretrained/ layout.
            if legacy_dir.is_dir() and any(legacy_dir.iterdir()):
                vlm_pretrained_dir = legacy_dir

        if vlm_pretrained_dir.is_dir() and any(vlm_pretrained_dir.iterdir()):
            logger.info(f"Found {vlm_pretrained_dir.name}/ in checkpoint, using single-read loading")
            cfg_key = _detect_vlm_cfg_key(config.framework)
            if cfg_key is not None:
                vlm_block = getattr(config.framework, cfg_key)
                # The config block may be dict-like (has .get) or attr-only.
                original_base_vlm = getattr(vlm_block, 'base_vlm', "") if not hasattr(vlm_block, 'get') else vlm_block.get('base_vlm', "")
                vlm_block.vlm_type = original_base_vlm
                vlm_block.base_vlm = str(vlm_pretrained_dir)
                vlm_block._meta_device_init = True
            else:
                logger.warning("vlm_pretrained/ found but no VLM config key detected, skipping single-read optimization")
        else:
            logger.warning(f"No vlm_pretrained/ found (or empty), falling back to two-read loading from original base_vlm")

        FrameworkModel = build_framework(cfg=config)
        FrameworkModel.norm_stats = norm_stats

        # lpt0309: locate the weights file inside the directory.
        weights_path = pretrained_checkpoint / "model.safetensors"
        if not weights_path.exists():
            weights_path = pretrained_checkpoint / "pytorch_model.pt"
        assert weights_path.exists(), f"[lpt0309] No weights file found in {pretrained_checkpoint}"

        if weights_path.suffix == ".safetensors":
            from safetensors.torch import load_file
            model_state_dict = load_file(str(weights_path))
        else:
            model_state_dict = torch.load(weights_path, map_location="cpu")

        logger.info(f"[lpt0309] Loading weights from {weights_path}")

        # Key remapping: old checkpoints use 'vlm.' prefix, new model uses 'qwen_vl_interface.'
        remapped = {}
        for k, v in model_state_dict.items():
            new_k = k.replace('vlm.', 'qwen_vl_interface.', 1) if k.startswith('vlm.') else k
            remapped[new_k] = v
        if len(remapped) != len(model_state_dict):
            logger.warning(f"Key remapping changed key count: {len(model_state_dict)} -> {len(remapped)}")
        else:
            n_remapped = sum(1 for k in model_state_dict if k.startswith('vlm.'))
            if n_remapped > 0:
                logger.info(f"Remapped {n_remapped} keys from vlm.* to qwen_vl_interface.*")
        model_state_dict = remapped

        model_keys = set(FrameworkModel.state_dict().keys())
        checkpoint_keys = set(model_state_dict.keys())
        # Try strict first; fall back to non-strict if only non-critical keys mismatch
        try:
            FrameworkModel.load_state_dict(model_state_dict, strict=True)
        except RuntimeError:
            common_keys = model_keys.intersection(checkpoint_keys)
            missing_keys = model_keys - common_keys
            unexpected_keys = checkpoint_keys - common_keys
            if missing_keys:
                logger.warning(f"Missing keys in state_dict ({len(missing_keys)}): {missing_keys}")
            if unexpected_keys:
                logger.warning(f"Unexpected keys in state_dict ({len(unexpected_keys)}): {unexpected_keys}")
            # Fall back to non-strict loading for cross-framework weight loading (e.g. openpi → AlphaBrain)
            logger.warning(f"Strict loading failed, falling back to non-strict (missing={len(missing_keys)}, unexpected={len(unexpected_keys)})")
            FrameworkModel.load_state_dict(model_state_dict, strict=False)

        logger.info(
            "[lpt0324] Successfully loaded model from self-contained checkpoint "
            "with legacy two-stage loading"
        )
        return FrameworkModel

    # origin0309: legacy file-format checkpoint — requires the original
    # base_vlm path and performs a redundant extra weight read.
    else:
        model_config, norm_stats = read_mode_config(pretrained_checkpoint)  # read config and norm_stats

        config = dict_to_namespace(model_config)
        model_config = config
        model_config.trainer.pretrained_checkpoint = None
        FrameworkModel = build_framework(cfg=model_config)
        # Attach stats for later action un-normalization.
        FrameworkModel.norm_stats = norm_stats
        # Load from checkpoint (custom format: covers both *projector* and *llm* weights).
        if pretrained_checkpoint.suffix == ".safetensors":
            from safetensors.torch import load_file
            model_state_dict = load_file(str(pretrained_checkpoint))
        else:
            model_state_dict = torch.load(pretrained_checkpoint, map_location="cpu")
        model_keys = set(FrameworkModel.state_dict().keys())
        checkpoint_keys = set(model_state_dict.keys())
        try:
            FrameworkModel.load_state_dict(model_state_dict, strict=True)
        except RuntimeError as e:
            # Legacy path requires an exact key match: report the diff, then re-raise.
            common_keys = model_keys.intersection(checkpoint_keys)
            missing_keys = model_keys - common_keys
            unexpected_keys = checkpoint_keys - common_keys
            if missing_keys:
                logger.warning(f"Missing keys in state_dict: {missing_keys}")
            if unexpected_keys:
                logger.warning(f"Unexpected keys in state_dict: {unexpected_keys}")

            raise e

        # Model is left on CPU; device placement is the caller's responsibility.
        return FrameworkModel
convert_checkpoint_to_dir staticmethod
convert_checkpoint_to_dir(old_ckpt_path: str, output_dir: str = None, base_vlm_path: str = None)

Convert an old-format file checkpoint to the new self-contained directory format.

Parameters:

Name Type Description Default
old_ckpt_path str

Path to old .safetensors/.pt checkpoint file.

required
output_dir str

Output directory path. If None, creates a directory alongside the file.

None
base_vlm_path str

Path to Qwen base model (for saving config + processor). If None, reads from the checkpoint's config.yaml.

None
Source code in AlphaBrain/model/framework/base_framework.py
@staticmethod
def convert_checkpoint_to_dir(
    old_ckpt_path: str,
    output_dir: str = None,
    base_vlm_path: str = None,
):
    """
    Convert an old-format file checkpoint to the new self-contained directory format.

    Args:
        old_ckpt_path: Path to old .safetensors/.pt checkpoint file.
        output_dir: Output directory path. If None, creates a directory alongside the file.
        base_vlm_path: Path to Qwen base model (for saving config + processor).
                      If None, reads from the checkpoint's config.yaml.
    """
    import shutil
    old_ckpt_path = Path(old_ckpt_path)
    assert old_ckpt_path.is_file(), f"Old checkpoint not found: {old_ckpt_path}"

    # Determine output directory: default to a sibling dir named after the
    # checkpoint file, minus the conventional "_model"/"_pytorch" suffixes.
    if output_dir is None:
        output_dir = old_ckpt_path.parent / old_ckpt_path.stem.replace("_model", "").replace("_pytorch", "")
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Copy weights under the canonical name for the new format.
    weights_name = "model.safetensors" if old_ckpt_path.suffix == ".safetensors" else "pytorch_model.pt"
    shutil.copy2(old_ckpt_path, output_dir / weights_name)
    logger.info(f"[lpt0309] Copied weights to {output_dir / weights_name}")

    # Copy config.yaml and dataset_statistics.json from run dir
    # (layout: <run_dir>/checkpoints/<ckpt file>, hence parents[1]).
    run_dir = old_ckpt_path.parents[1]
    for fname, target_name in [("config.yaml", "framework_config.yaml"), ("dataset_statistics.json", "dataset_statistics.json")]:
        src = run_dir / fname
        if src.exists():
            shutil.copy2(src, output_dir / target_name)
            logger.info(f"[lpt0309] Copied {fname} -> {output_dir / target_name}")

    # Save VLM config + processor (auto-detect VLM type from config)
    if base_vlm_path is None:
        config_yaml = run_dir / "config.yaml"
        if config_yaml.exists():
            from omegaconf import OmegaConf
            cfg = OmegaConf.load(str(config_yaml))
            base_vlm_path = _get_base_vlm_path(cfg.framework)

    if base_vlm_path:
        vlm_pretrained_dir = output_dir / "vlm_pretrained"
        vlm_pretrained_dir.mkdir(parents=True, exist_ok=True)
        # Best-effort: conversion still succeeds without the VLM snapshot,
        # loading will just fall back to the original base_vlm path.
        try:
            from transformers import AutoConfig, AutoProcessor
            vlm_config = AutoConfig.from_pretrained(base_vlm_path, trust_remote_code=True)
            vlm_config.save_pretrained(str(vlm_pretrained_dir))
            processor = AutoProcessor.from_pretrained(base_vlm_path)
            processor.save_pretrained(str(vlm_pretrained_dir))
            logger.info(f"[lpt0309] Saved VLM config + processor to {vlm_pretrained_dir}")
        except Exception as e:
            logger.warning(f"[lpt0309] Could not save VLM config/processor from {base_vlm_path}: {e}")

    logger.info(f"[lpt0309] Conversion complete: {output_dir}")
get_action_stats classmethod
get_action_stats(unnorm_key=None)

Retrieve raw action normalization statistics.

Parameters:

Name Type Description Default
unnorm_key

Optional dataset stats key.

None

Returns:

Name Type Description
dict

Stats structure (e.g. q01, q99, mask).

Source code in AlphaBrain/model/framework/base_framework.py
@classmethod
def get_action_stats(self, unnorm_key=None):
    """
    Retrieve raw action normalization statistics.

    Args:
        unnorm_key: Optional dataset stats key.

    Returns:
        dict: Stats structure (e.g. q01, q99, mask).
    """
    # NOTE(review): declared @classmethod yet the first parameter is named
    # `self` and it reads `self.norm_stats`, which is only assigned on
    # instances (in from_pretrained) — calling this on the class itself
    # would raise AttributeError. Looks like it should be a plain instance
    # method; confirm against callers before changing the decorator.
    unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
    return self.norm_stats[unnorm_key]["action"]
unnormalize_actions staticmethod
unnormalize_actions(normalized_actions: ndarray, action_norm_stats: Dict[str, ndarray]) -> np.ndarray

Map normalized actions (≈[-1, 1]) back to original value range.

Auto-detects normalization mode via the optional 'norm_mode' key in action_norm_stats (defaults to 'q99' for backward compatibility): - 'q99' → uses q01 / q99 bounds - 'min_max' → uses min / max bounds

Steps
  • Clamp values to [-1, 1]
  • Threshold channel index 6 to {0,1} (binary semantic)
  • Apply linear scaling for masked dimensions

Parameters:

Name Type Description Default
normalized_actions ndarray

Array shape [T, D] (or chunk length × action_dim).

required
action_norm_stats Dict[str, ndarray]

Dict containing stat arrays and optional 'norm_mode'.

required

Returns:

Type Description
ndarray

np.ndarray: Unnormalized actions (same shape as input).

Source code in AlphaBrain/model/framework/base_framework.py
@staticmethod
def unnormalize_actions(normalized_actions: np.ndarray, action_norm_stats: Dict[str, np.ndarray]) -> np.ndarray:
    """
    Map normalized actions (≈[-1, 1]) back to original value range.

    Auto-detects normalization mode via the optional 'norm_mode' key in
    action_norm_stats (defaults to 'q99' for backward compatibility):
        - 'q99'     → uses q01 / q99 bounds
        - 'min_max' → uses min / max bounds

    Steps:
        - Clamp values to [-1, 1]
        - Threshold channel index 6 to {0,1} (binary semantic)
        - Apply linear scaling for masked dimensions

    Args:
        normalized_actions: Array shape [T, D] (or chunk length × action_dim).
        action_norm_stats: Dict containing stat arrays and optional 'norm_mode'.

    Returns:
        np.ndarray: Unnormalized actions (same shape as input).
    """
    norm_mode = action_norm_stats.get("norm_mode", "q99")
    if norm_mode == "min_max":
        ref_key_high, ref_key_low = "max", "min"
    else:
        ref_key_high, ref_key_low = "q99", "q01"
    mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats[ref_key_low], dtype=bool))
    action_high = np.array(action_norm_stats[ref_key_high])
    action_low = np.array(action_norm_stats[ref_key_low])
    normalized_actions = np.clip(normalized_actions, -1, 1)
    normalized_actions[:, 6] = np.where(normalized_actions[:, 6] < 0.5, 0, 1)
    actions = np.where(
        mask,
        0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
        normalized_actions,
    )

    return actions

config_utils

Shared configuration / utility helpers for framework components: - NamespaceWithGet: lightweight namespace behaving like a dict - OmegaConf conversion helpers - Config merging decorator for model init - Checkpoint config/statistics loading

NamespaceWithGet

Bases: SimpleNamespace

get
get(key, default=None)

Return attribute value if present, else default (dict-like API).

Parameters:

Name Type Description Default
key

Attribute name.

required
default

Fallback if attribute missing.

None

Returns:

Name Type Description
Any

Stored value or default.

Source code in AlphaBrain/model/framework/config_utils.py
def get(self, key, default=None):
    """
    Dict-style lookup over namespace attributes.

    Args:
        key: Attribute name to look up.
        default: Value returned when the attribute is absent.

    Returns:
        Any: The stored attribute value, or ``default``.
    """
    try:
        return getattr(self, key)
    except AttributeError:
        return default
items
items()

Iterate (key, value) pairs like dict.items().

Returns:

Type Description

Generator[Tuple[str, Any], None, None]

Source code in AlphaBrain/model/framework/config_utils.py
def items(self):
    """
    Yield (key, value) pairs, mirroring dict.items().

    Yields:
        Tuple[str, Any]: Attribute name / value pairs in insertion order.
    """
    for key in self.__dict__:
        yield key, getattr(self, key)
to_dict
to_dict()

Recursively convert nested NamespaceWithGet objects into plain dicts.

Returns:

Name Type Description
dict

Fully materialized dictionary structure.

Source code in AlphaBrain/model/framework/config_utils.py
def to_dict(self):
    """
    Recursively convert nested NamespaceWithGet objects into plain dicts.

    Returns:
        dict: Fully materialized dictionary structure.
    """
    plain = {}
    for key, value in self.items():
        if isinstance(value, NamespaceWithGet):
            plain[key] = value.to_dict()
        else:
            plain[key] = value
    return plain

dict_to_namespace

dict_to_namespace(d)

Create an OmegaConf config from a plain dictionary.

Parameters:

Name Type Description Default
d

Input dictionary.

required

Returns:

Name Type Description
OmegaConf

DictConfig instance.

Source code in AlphaBrain/model/framework/config_utils.py
def dict_to_namespace(d):
    """
    Wrap a plain dictionary in an OmegaConf config.

    Args:
        d: Input dictionary (may be arbitrarily nested).

    Returns:
        OmegaConf: DictConfig mirroring ``d``.
    """
    cfg = OmegaConf.create(d)
    return cfg

merge_param_config

merge_param_config(init)

Decorator for init to unify config handling.

Behavior
  1. Extract 'config' kwarg / arg (path | dict | OmegaConf | namespace)
  2. Convert to OmegaConf
  3. Merge with explicitly passed init parameters (explicit overrides file)
  4. Attach merged config to self.config
  5. Call original init with merged config

Parameters:

Name Type Description Default
init

Original init function.

required

Returns:

Type Description

Wrapped initializer.

Source code in AlphaBrain/model/framework/config_utils.py
def merge_param_config(init):
    """
    Decorator for __init__ to unify config handling.

    Behavior:
        1. Extract 'config' kwarg / arg (path | dict | OmegaConf | namespace)
        2. Convert to OmegaConf
        3. Merge with explicitly passed init parameters (explicit overrides file)
        4. Attach merged config to self.config
        5. Call original __init__ with merged config

    Args:
        init: Original __init__ function.

    Returns:
        Wrapped initializer.
    """

    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        # Map positional args to parameter names; slicing off the first entry
        # drops `self` (simpler than the previous enumerate-and-filter idiom).
        param_names = list(inspect.signature(init).parameters)[1:]

        init_kwargs = dict(zip(param_names, args))
        # Explicit keyword arguments take precedence over positionals.
        init_kwargs.update(kwargs)

        # Normalize whatever `config` form was provided (path / dict /
        # OmegaConf / namespace) into an OmegaConf object.
        provided_config = init_kwargs.get("config", None)
        loaded_cfg = _to_omegaconf(provided_config)

        # Build a config from the remaining explicit init args (other than config).
        params = {k: v for k, v in init_kwargs.items() if k != "config"}
        params_cfg = OmegaConf.create(params) if params else OmegaConf.create({})

        # Merge: params override values loaded from file.
        merged = OmegaConf.merge(loaded_cfg, params_cfg)

        # Attach to the instance; fall back to a plain dict if assigning the
        # OmegaConf object fails for any reason.
        try:
            self.config = merged
        except Exception:
            self.config = OmegaConf.to_container(merged, resolve=True)

        # Call the original __init__ keyword-only (safer), with the merged
        # OmegaConf bound to `config`.
        call_kwargs = dict(init_kwargs)
        call_kwargs["config"] = merged
        return init(self, **call_kwargs)

    return wrapper

read_model_config

read_model_config(pretrained_checkpoint)

Load global model configuration and dataset normalization statistics associated with a saved checkpoint (.pt).

Expected directory layout

<run_dir>/checkpoints/<name>.pt · <run_dir>/config.json · <run_dir>/dataset_statistics.json

Parameters:

Name Type Description Default
pretrained_checkpoint

Path to a .pt checkpoint file.

required

Returns:

Name Type Description
tuple

global_cfg (dict): Loaded config.json contents. norm_stats (dict): Dataset statistics for (de)normalization.

Raises:

Type Description
FileNotFoundError

If checkpoint or required JSON files are missing.

AssertionError

If file suffix or structure invalid.

Source code in AlphaBrain/model/framework/config_utils.py
def read_model_config(pretrained_checkpoint):
    """
    Load global model configuration and dataset normalization statistics
    associated with a saved checkpoint (.pt).

    Expected directory layout:
        <run_dir>/checkpoints/<name>.pt
        <run_dir>/config.json
        <run_dir>/dataset_statistics.json

    Args:
        pretrained_checkpoint: Path to a .pt checkpoint file.

    Returns:
        tuple:
            global_cfg (dict): Loaded config.json contents.
            norm_stats (dict): Dataset statistics for (de)normalization.

    Raises:
        FileNotFoundError: If checkpoint or required JSON files are missing.
        AssertionError: If file suffix or structure invalid.
    """
    # Guard clause: bail out early when the checkpoint file does not exist.
    if not os.path.isfile(pretrained_checkpoint):
        logger.error(f"❌ Pretrained checkpoint `{pretrained_checkpoint}` does not exist.")
        raise FileNotFoundError(f"Pretrained checkpoint `{pretrained_checkpoint}` does not exist.")

    checkpoint_pt = Path(pretrained_checkpoint)
    logger.info(f"Loading from local checkpoint path `{checkpoint_pt}`")

    # [Validate] Checkpoint Path should look like
    # `.../<RUN_ID>/checkpoints/<CHECKPOINT_PATH>.pt|.safetensors`
    assert checkpoint_pt.suffix in {".pt", ".safetensors"}
    run_dir = checkpoint_pt.parents[1]

    # Resolve sibling metadata files next to the run directory.
    config_json = run_dir / "config.json"
    dataset_statistics_json = run_dir / "dataset_statistics.json"
    assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
    assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"

    # Load VLA config (and corresponding base VLM `ModelConfig`) from `config.json`.
    with open(config_json, "r") as f:
        global_cfg = json.load(f)

    # Load dataset statistics for action denormalization.
    with open(dataset_statistics_json, "r") as f:
        norm_stats = json.load(f)

    return global_cfg, norm_stats

read_mode_config

read_mode_config(pretrained_checkpoint)

Same as read_model_config (legacy duplicate kept for backward compatibility).

Parameters:

Name Type Description Default
pretrained_checkpoint

Path to a .pt checkpoint file.

required

Returns:

Name Type Description
tuple

vla_cfg (dict) norm_stats (dict)

Source code in AlphaBrain/model/framework/config_utils.py
def read_mode_config(pretrained_checkpoint):
    """
    Same as read_model_config (legacy duplicate kept for backward compatibility).

    Unlike read_model_config, this variant additionally supports the
    self-contained directory checkpoint format and YAML configs.

    Args:
        pretrained_checkpoint: Path to a .pt checkpoint file, or a
            self-contained checkpoint directory.

    Returns:
        tuple:
            vla_cfg (dict)
            norm_stats (dict)

    Raises:
        FileNotFoundError: If the path is neither a file nor a directory.
        AssertionError: If required config/statistics files are missing.
    """
    # lpt0309: support self-contained directory-format checkpoints
    if os.path.isdir(pretrained_checkpoint):
        checkpoint_dir = Path(pretrained_checkpoint)
        logger.info(f"[lpt0309] Loading from self-contained checkpoint directory `{checkpoint_dir}`")

        # lpt0309: read framework_config.yaml from the checkpoint directory
        config_yaml = checkpoint_dir / "framework_config.yaml"
        assert config_yaml.exists(), f"[lpt0309] Missing `framework_config.yaml` in checkpoint dir `{checkpoint_dir}`"

        try:
            ocfg = OmegaConf.load(str(config_yaml))
            global_cfg = OmegaConf.to_container(ocfg, resolve=True)
        except Exception as e:
            logger.error(f"❌ Failed to load YAML config `{config_yaml}`: {e}")
            raise

        # lpt0309: read dataset_statistics.json from the checkpoint directory
        # (optional — missing stats just produce an empty dict).
        dataset_statistics_json = checkpoint_dir / "dataset_statistics.json"
        norm_stats = {}
        if dataset_statistics_json.exists():
            with open(dataset_statistics_json, "r") as f:
                norm_stats = json.load(f)
        else:
            logger.warning(f"[lpt0309] No dataset_statistics.json found in {checkpoint_dir}, norm_stats will be empty")

        return global_cfg, norm_stats

    # origin0309: original file-format loading
    elif os.path.isfile(pretrained_checkpoint):
        logger.info(f"Loading from local checkpoint path `{(checkpoint_pt := Path(pretrained_checkpoint))}`")

        # [Validate] Checkpoint Path should look like
        # `.../<RUN_ID>/checkpoints/<CHECKPOINT_PATH>.pt|.safetensors`
        assert checkpoint_pt.suffix in {".pt", ".safetensors"}
        run_dir = checkpoint_pt.parents[1]

        # Get paths for `config.yaml`, `dataset_statistics.json` in the run dir
        config_yaml, dataset_statistics_json = run_dir / "config.yaml", run_dir / "dataset_statistics.json"
        assert config_yaml.exists(), f"Missing `config.yaml` for `{run_dir = }`"
        assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"

        # Load VLA Config (and corresponding base VLM `ModelConfig`) from the YAML
        try:
            ocfg = OmegaConf.load(str(config_yaml))
            global_cfg = OmegaConf.to_container(ocfg, resolve=True)
        except Exception as e:
            logger.error(f"❌ Failed to load YAML config `{config_yaml}`: {e}")
            raise

        # Load Dataset Statistics for Action Denormalization
        with open(dataset_statistics_json, "r") as f:
            norm_stats = json.load(f)
    else:
        logger.error(f"❌ Pretrained checkpoint `{pretrained_checkpoint}` does not exist.")
        raise FileNotFoundError(f"Pretrained checkpoint `{pretrained_checkpoint}` does not exist.")
    return global_cfg, norm_stats

ToyVLA

ToyModel

ToyVLA — 极简 VLA 调试模型

设计目标: - 无 VLM 依赖, 无需 Qwen / LLM,秒级加载 - 接口与 QwenOFT 完全一致 (forward / predict_action 接受同样的 examples List[dict]) - 能在几百步内 overfit 小样本 → 验证训练管线是否正确

验证方法
  1. 把 N 个固定样本喂进去,train 几百步
  2. 如果 action_loss 接近 0、eval MSE 接近 0 → 管线正确
  3. 否则说明 data → forward → loss → backward 链路有 bug

Interface (与 QwenOFT 相同): examples: List[dict] - "image" : List[PIL.Image] (multi-view, 各尺寸均可) - "lang" : str - "action" : np.ndarray shape (T, action_dim)

forward(examples) → {"action_loss": scalar_tensor} predict_action(examples) → {"normalized_actions": np.ndarray (B, chunk_len, action_dim)}

TinyImageEncoder

TinyImageEncoder(img_feat_dim: int = 128)

Bases: Module

把任意尺寸 PIL Image 压成 (img_feat_dim,) 向量,纯卷积,参数量 ~10K

Source code in AlphaBrain/model/framework/ToyModel.py
def __init__(self, img_feat_dim: int = 128):
    """Tiny conv encoder: any-size RGB image → (img_feat_dim,) vector, ~10K params."""
    super().__init__()
    self.img_feat_dim = img_feat_dim
    # Two aggressively-strided convs, then global average pooling so any
    # input resolution collapses to a single (32,)-channel descriptor.
    conv_stack = [
        nn.Conv2d(3, 16, 8, stride=8),   # 224 → 28
        nn.ReLU(),
        nn.Conv2d(16, 32, 4, stride=4),  # 28 → 7
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),         # → (32, 1, 1)
    ]
    self.net = nn.Sequential(*conv_stack)
    self.proj = nn.Linear(32, img_feat_dim)
forward
forward(x: Tensor) -> torch.Tensor

x: (B, 3, H, W) → (B, img_feat_dim)

Source code in AlphaBrain/model/framework/ToyModel.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Encode a batch of images.

    Args:
        x: (B, 3, H, W) float tensor.
    Returns:
        (B, img_feat_dim) feature tensor.
    """
    pooled = self.net(x).flatten(1)
    return self.proj(pooled)

TinyTextEncoder

TinyTextEncoder(vocab_size: int = 256, text_feat_dim: int = 64)

Bases: Module

Source code in AlphaBrain/model/framework/ToyModel.py
def __init__(self, vocab_size: int = 256, text_feat_dim: int = 64):
    """Hash-bucket text encoder: each string maps to one of `vocab_size` embeddings."""
    super().__init__()
    self.vocab_size = vocab_size
    self.text_feat_dim = text_feat_dim
    self.emb = nn.Embedding(vocab_size, text_feat_dim)
forward
forward(texts: List[str]) -> torch.Tensor

texts: List[str] → (B, text_feat_dim)

Source code in AlphaBrain/model/framework/ToyModel.py
def forward(self, texts: List[str]) -> torch.Tensor:
    """Encode a batch of strings.

    Each string is md5-hashed to a stable index in [0, vocab_size) and
    looked up in the embedding table — crude, but distinct texts usually
    land in distinct buckets, which is enough for overfitting tests.

    Args:
        texts: list of B strings.
    Returns:
        (B, text_feat_dim) embedding tensor.
    """
    bucket_ids = [
        int(hashlib.md5(text.encode()).hexdigest(), 16) % self.vocab_size
        for text in texts
    ]
    id_tensor = torch.tensor(bucket_ids, device=self.emb.weight.device)
    return self.emb(id_tensor)  # (B, text_feat_dim)

ToyVLA

ToyVLA(config=None, **kwargs)

Bases: PreTrainedModel

极简 VLA 调试模型。 - 用 TinyImageEncoder + TinyTextEncoder 代替 Qwen VLM - 用小 MLP 做动作回归 - 整体 < 200K 参数,单卡几秒即可 overfit 小 batch

Source code in AlphaBrain/model/framework/ToyModel.py
def __init__(self, config=None, **kwargs):
    """Build the toy model.

    Hyper-parameters are read from ``config.framework.action_model`` when
    present; otherwise sensible debugging defaults are used.
    """
    super().__init__(PretrainedConfig())
    self.toy_config = config  # OmegaConf

    am_cfg = config.framework.action_model if hasattr(config, "framework") else None

    def _hp(name, default):
        # One hyper-parameter lookup with fallback when config is absent.
        return getattr(am_cfg, name, default) if am_cfg else default

    self.action_dim    = _hp("action_dim", 7)
    self.future_window = _hp("future_action_window_size", 15)
    self.past_window   = _hp("past_action_window_size", 0)
    self.chunk_len     = self.past_window + 1 + self.future_window

    img_feat_dim  = 128
    text_feat_dim = 64
    fuse_dim      = img_feat_dim + text_feat_dim  # 192

    self.img_encoder  = TinyImageEncoder(img_feat_dim)
    self.text_encoder = TinyTextEncoder(text_feat_dim=text_feat_dim)

    # Action head: fuse_dim → (chunk_len * action_dim); callers reshape.
    self.action_head = nn.Sequential(
        nn.Linear(fuse_dim, 256),
        nn.ReLU(),
        nn.LayerNorm(256),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, self.chunk_len * self.action_dim),
    )
    self.l1_loss = nn.L1Loss()

    logger.info(
        f"[ToyVLA] Built: action_dim={self.action_dim}, chunk_len={self.chunk_len}, "
        f"params={sum(p.numel() for p in self.parameters())/1e3:.1f}K"
    )
forward
forward(examples: List[dict], **kwargs) -> dict

Returns:

Type Description
dict

{"action_loss": scalar tensor}

Source code in AlphaBrain/model/framework/ToyModel.py
def forward(self, examples: List[dict], **kwargs) -> dict:
    """Compute the L1 action-regression loss for a batch.

    Args:
        examples: list of dicts; each must carry "action" with shape
            (T, action_dim) plus whatever fields _encode_batch consumes.
    Returns:
        {"action_loss": scalar tensor}
    """
    features = self._encode_batch(examples)                    # (B, fuse_dim)
    predicted = self.action_head(features).view(
        -1, self.chunk_len, self.action_dim
    )                                                          # (B, chunk_len, action_dim)

    # Regression target: the last chunk_len steps of each action sequence.
    target_np = np.array([ex["action"] for ex in examples])    # (B, T, D)
    target = torch.tensor(
        target_np, device=predicted.device, dtype=predicted.dtype
    )[:, -self.chunk_len:, :]                                  # (B, chunk_len, D)

    return {"action_loss": self.l1_loss(predicted, target)}
predict_action
predict_action(examples: List[dict], **kwargs) -> dict

Returns:

Type Description
dict

{"normalized_actions": np.ndarray (B, chunk_len, action_dim)}

Source code in AlphaBrain/model/framework/ToyModel.py
def predict_action(self, examples: List[dict], **kwargs) -> dict:
    """Predict a normalized action chunk per example (no gradients).

    Returns:
        {"normalized_actions": np.ndarray (B, chunk_len, action_dim)}
    """
    with torch.no_grad():
        features = self._encode_batch(examples)
        chunk = self.action_head(features).view(-1, self.chunk_len, self.action_dim)
    return {"normalized_actions": chunk.cpu().float().numpy()}

ACT

ACT

ACT — Action Chunking Transformers (standalone implementation)

Reference

Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware Zhao et al., RSS 2023

Architecture
  • ResNet18 visual encoder (per camera view)
  • CVAE encoder: (robot_state, action_chunk) → z (training only; z=0 at inference)
  • Transformer encoder: [z_token, img_tokens, state_token] → memory
  • Transformer decoder: query_embed → action_chunk

Interface (same as QwenOFT / ToyVLA): examples: List[dict] - "image" : List[PIL.Image] (multi-view, any size) - "lang" : str (ignored during action prediction, kept for API compat) - "action" : np.ndarray shape (T, action_dim) - "state" : np.ndarray shape (T_state, state_dim) [optional]

forward(examples) → {"action_loss": tensor} predict_action(examples) → {"normalized_actions": np.ndarray (B, chunk_len, action_dim)}

CVAEEncoder

CVAEEncoder(state_dim: int, action_dim: int, hidden_dim: int, latent_dim: int, num_heads: int = 4, num_layers: int = 2)

Bases: Module

Encodes (robot_state, action_chunk) → (mu, log_var). Inputs are projected to hidden_dim then fused through a small Transformer encoder.

Source code in AlphaBrain/model/framework/ACT.py
def __init__(self, state_dim: int, action_dim: int, hidden_dim: int, latent_dim: int, num_heads: int = 4, num_layers: int = 2):
    """CVAE posterior encoder: (state, action chunk) → (mu, log_var).

    Args:
        state_dim: robot state dimensionality.
        action_dim: per-step action dimensionality.
        hidden_dim: Transformer width.
        latent_dim: size of the latent z.
        num_heads: attention heads in the fusion Transformer.
        num_layers: number of Transformer encoder layers.
    """
    super().__init__()
    # Project both modalities into the shared hidden space.
    self.state_proj  = nn.Linear(state_dim,  hidden_dim)
    self.action_proj = nn.Linear(action_dim, hidden_dim)
    # Learnable CLS-style token that aggregates the whole sequence.
    self.cls_token   = nn.Parameter(torch.zeros(1, 1, hidden_dim))
    fuse_layer = nn.TransformerEncoderLayer(
        d_model=hidden_dim, nhead=num_heads, dim_feedforward=hidden_dim * 4,
        batch_first=True, dropout=0.1, activation="gelu", norm_first=True,
    )
    self.encoder = nn.TransformerEncoder(fuse_layer, num_layers=num_layers)
    # Separate heads for the Gaussian posterior parameters.
    self.mu_head  = nn.Linear(hidden_dim, latent_dim)
    self.var_head = nn.Linear(hidden_dim, latent_dim)
forward
forward(state: Tensor, action_chunk: Tensor) -> Tuple[torch.Tensor, torch.Tensor]

state: (B, state_dim) action_chunk: (B, chunk_len, action_dim) Returns: mu, log_var each (B, latent_dim)

Source code in AlphaBrain/model/framework/ACT.py
def forward(self, state: torch.Tensor, action_chunk: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode (state, action chunk) into Gaussian posterior parameters.

    Args:
        state: (B, state_dim)
        action_chunk: (B, chunk_len, action_dim)
    Returns:
        (mu, log_var), each of shape (B, latent_dim).
    """
    batch = state.shape[0]
    tokens = torch.cat(
        [
            self.cls_token.expand(batch, -1, -1),   # (B, 1, H) learnable CLS
            self.state_proj(state).unsqueeze(1),    # (B, 1, H)
            self.action_proj(action_chunk),         # (B, T, H)
        ],
        dim=1,
    )                                               # (B, 2+T, H)
    pooled = self.encoder(tokens)[:, 0, :]          # CLS output, (B, H)
    return self.mu_head(pooled), self.var_head(pooled)

ACTModel

ACTModel(config=None, **kwargs)

Bases: PreTrainedModel

Standalone ACT (Action Chunking Transformers) model.

Key design choices vs. paper: - Use torchvision's stock ResNet18 as the visual encoder, without the paper's backbone unfreezing - Replace FiLM conditioning with simple token concatenation - Use PyTorch native Transformer encoder / decoder

Source code in AlphaBrain/model/framework/ACT.py
def __init__(self, config=None, **kwargs):
    """
    Build the ACT model.

    Args:
        config: OmegaConf-style config. Hyper-parameters are read from
            ``config.framework.action_model`` when present; otherwise the
            inline defaults below are used.
        **kwargs: unused here; accepted for factory-signature compatibility.
    """
    super().__init__(PretrainedConfig())
    self.act_config = config

    # -------- hyper-parameters from config --------
    am_cfg = config.framework.action_model if hasattr(config, "framework") else None

    self.action_dim      = getattr(am_cfg, "action_dim", 7)         if am_cfg else 7
    self.state_dim       = getattr(am_cfg, "state_dim", 8)          if am_cfg else 8
    self.hidden_dim      = getattr(am_cfg, "hidden_dim", 256)       if am_cfg else 256
    self.latent_dim      = getattr(am_cfg, "latent_dim", 32)        if am_cfg else 32
    self.num_heads       = getattr(am_cfg, "num_heads", 8)          if am_cfg else 8
    self.num_enc_layers  = getattr(am_cfg, "num_enc_layers", 4)     if am_cfg else 4
    self.num_dec_layers  = getattr(am_cfg, "num_dec_layers", 7)     if am_cfg else 7
    self.kl_weight       = getattr(am_cfg, "kl_weight", 10.0)       if am_cfg else 10.0
    # chunk_len = current step + future window
    self.chunk_len       = (
        getattr(am_cfg, "future_action_window_size", 7) + 1
        if am_cfg else 8
    )
    self.n_views         = getattr(am_cfg, "n_views", 2)            if am_cfg else 2

    H = self.hidden_dim

    # -------- Visual backbone (ResNet18, per view) --------
    backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Remove the final FC; keep layers up to avgpool → output: (B*n_views, 512)
    self.visual_backbone = nn.Sequential(*list(backbone.children())[:-1])
    self.visual_proj     = nn.Linear(512, H)

    # -------- State projection --------
    self.state_proj = nn.Linear(self.state_dim, H)

    # -------- Latent z projection --------
    self.latent_proj = nn.Linear(self.latent_dim, H)

    # -------- CVAE encoder (training only) --------
    self.cvae_encoder = CVAEEncoder(
        state_dim=self.state_dim,
        action_dim=self.action_dim,
        hidden_dim=H,
        latent_dim=self.latent_dim,
        # Cap heads at H//32 so each head keeps at least 32 dims.
        num_heads=min(self.num_heads, H // 32),
        num_layers=2,
    )

    # -------- Positional encodings --------
    # We have:  1 (z) + n_views (img) + 1 (state) = total context tokens
    max_ctx = 1 + self.n_views + 1
    # One table covers both context tokens and decoder query positions.
    self.pos_enc = nn.Embedding(max_ctx + self.chunk_len, H)

    # -------- Main Transformer encoder --------
    enc_layer = nn.TransformerEncoderLayer(
        d_model=H, nhead=self.num_heads, dim_feedforward=H * 4,
        batch_first=True, dropout=0.1, activation="gelu", norm_first=True,
    )
    self.transformer_encoder = nn.TransformerEncoder(enc_layer, num_layers=self.num_enc_layers)

    # -------- Action query embeddings --------
    # One learned query per predicted timestep (DETR-style).
    self.query_embed = nn.Embedding(self.chunk_len, H)

    # -------- Transformer decoder --------
    dec_layer = nn.TransformerDecoderLayer(
        d_model=H, nhead=self.num_heads, dim_feedforward=H * 4,
        batch_first=True, dropout=0.1, activation="gelu", norm_first=True,
    )
    self.transformer_decoder = nn.TransformerDecoder(dec_layer, num_layers=self.num_dec_layers)

    # -------- Action head --------
    self.action_head = nn.Linear(H, self.action_dim)

    self.l1_loss = nn.L1Loss()

    n_params = sum(p.numel() for p in self.parameters()) / 1e6
    logger.info(
        f"[ACT] Built: action_dim={self.action_dim}, chunk_len={self.chunk_len}, "
        f"hidden_dim={H}, latent_dim={self.latent_dim}, "
        f"enc_layers={self.num_enc_layers}, dec_layers={self.num_dec_layers}, "
        f"n_views={self.n_views}, params={n_params:.1f}M"
    )
forward
forward(examples: List[dict], **kwargs) -> dict

Returns:

Type Description
dict

dict with keys: action_loss (scalar) kl_loss (scalar)

Source code in AlphaBrain/model/framework/ACT.py
def forward(self, examples: List[dict], **kwargs) -> dict:
    """
    CVAE training pass: encode (state, actions) → z, then reconstruct the
    action chunk from (images, state, z).

    Args:
        examples: list of dicts; each must contain "image" (list of views)
            and "action" (np.ndarray (T, action_dim)); "state" is consumed
            by _get_state when present.

    Returns:
        dict with keys:
            action_loss  (scalar)  = l1 + kl_weight * kl (backprop target)
            l1_loss      (scalar, detached, for logging)
            kl_loss      (scalar, detached, for logging)
    """
    device = next(self.parameters()).device
    dtype  = next(self.parameters()).dtype
    B = len(examples)

    # 1. Extract ground-truth actions
    actions_np = np.array([np.array(ex["action"], dtype=np.float32) for ex in examples])  # (B, T, D)
    T = actions_np.shape[1]
    # Align chunk: take the first chunk_len steps if T >= chunk_len, else pad with last
    if T >= self.chunk_len:
        actions_np = actions_np[:, :self.chunk_len, :]
    else:
        # Repeat the final action to pad short sequences up to chunk_len.
        pad = np.repeat(actions_np[:, -1:, :], self.chunk_len - T, axis=1)
        actions_np = np.concatenate([actions_np, pad], axis=1)
    actions = torch.tensor(actions_np, device=device, dtype=dtype)  # (B, chunk_len, D)

    # 2. Visual features
    imgs_batch = [ex["image"] for ex in examples]
    # Pad or truncate to n_views
    # NOTE(review): assumes each example has at least one view; views[-1]
    # raises IndexError on an empty list — confirm against the data loader.
    imgs_batch = [views[:self.n_views] + [views[-1]] * max(0, self.n_views - len(views)) for views in imgs_batch]
    img_feats  = self._extract_visual_features(imgs_batch)           # (B, n_views, H)

    # 3. Robot state
    state_raw  = self._get_state(examples)                           # (B, state_dim)
    state_feat = self.state_proj(state_raw)                          # (B, H)

    # 4. CVAE encoder → z  (reparameterization trick: z = mu + eps * std)
    mu, log_var = self.cvae_encoder(state_raw, actions)              # each (B, latent_dim)
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    z_raw = mu + eps * std                                           # (B, latent_dim)
    z = self.latent_proj(z_raw)                                      # (B, H)

    # 5. Transformer encoder
    ctx = self._build_encoder_input(img_feats, state_feat, z)        # (B, 1+V+1, H)
    memory = self.transformer_encoder(ctx)                           # (B, 1+V+1, H)

    # 6. Transformer decoder → action predictions
    n_ctx = ctx.shape[1]
    q_idx = torch.arange(self.chunk_len, device=device)
    queries = self.query_embed(q_idx).unsqueeze(0).expand(B, -1, -1)  # (B, chunk_len, H)
    # add positional embeddings to queries (offset past the context slots)
    q_pos_idx = torch.arange(n_ctx, n_ctx + self.chunk_len, device=device)
    queries = queries + self.pos_enc(q_pos_idx).unsqueeze(0)
    decoded = self.transformer_decoder(queries, memory)               # (B, chunk_len, H)
    pred_actions = self.action_head(decoded)                          # (B, chunk_len, D)

    # 7. L1 reconstruction loss
    l1 = self.l1_loss(pred_actions, actions)

    # 8. KL divergence loss (vs. standard normal prior)
    kl = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())

    action_loss = l1 + self.kl_weight * kl

    return {
        "action_loss": action_loss,
        "l1_loss": l1.detach(),
        "kl_loss": kl.detach(),
    }
predict_action
predict_action(examples: List[dict] = None, batch_images: List[List] = None, instructions: List[str] = None, states: ndarray = None, **kwargs) -> dict

Accepts two input formats:

  1. examples format (train / debug): examples = [{"image": [PIL,...], "lang": str, "state": np.ndarray}, ...]

  2. Flat format (from websocket server / M1Inference): batch_images = [[img0, img1], ...] (B × n_views, np.ndarray or PIL) instructions = ["task description", ...] states = np.ndarray (B, T, state_dim) or (B, state_dim)

Returns:

Name Type Description
dict dict

normalized_actions: np.ndarray (B, chunk_len, action_dim)

Source code in AlphaBrain/model/framework/ACT.py
@torch.inference_mode()
def predict_action(
    self,
    examples: List[dict] = None,
    # ---- flat format (from websocket client / M1Inference) ----
    batch_images: List[List] = None,   # B × n_views, each element is np.ndarray or PIL
    instructions: List[str] = None,
    states: np.ndarray = None,         # (B, T, state_dim) or (B, state_dim)
    **kwargs,
) -> dict:
    """
    Accepts two input formats:

    1. examples format (train / debug):
           examples = [{"image": [PIL,...], "lang": str, "state": np.ndarray}, ...]

    2. Flat format (from websocket server / M1Inference):
           batch_images = [[img0, img1], ...]   (B × n_views, np.ndarray or PIL)
           instructions = ["task description", ...]
           states       = np.ndarray (B, T, state_dim) or (B, state_dim)

    At inference, z is fixed to 0 (the prior mean) instead of sampling,
    per the ACT paper.

    Returns:
        dict:
            normalized_actions: np.ndarray  (B, chunk_len, action_dim)
    """
    # ---- convert flat format → examples ----
    if examples is None:
        assert batch_images is not None, "Either examples or batch_images must be provided"
        B = len(batch_images)
        examples = []
        for i in range(B):
            imgs = []
            for img in batch_images[i]:
                if isinstance(img, np.ndarray):
                    imgs.append(Image.fromarray(img.astype(np.uint8)))
                else:
                    imgs.append(img)
            state = None
            if states is not None:
                s = np.array(states[i])
                if s.ndim == 2:
                    s = s[-1]   # take the most recent timestep
                state = s[np.newaxis, :]   # (1, state_dim)
            # NOTE(review): assumes len(instructions) == len(batch_images)
            # when instructions is provided — confirm at the call site.
            examples.append({
                "image":  imgs,
                "lang":   instructions[i] if instructions else "",
                "state":  state,
                # Dummy actions keep the dict schema identical to training.
                "action": np.zeros((self.chunk_len, self.action_dim), dtype=np.float32),
            })

    device = next(self.parameters()).device
    dtype  = next(self.parameters()).dtype
    B = len(examples)

    # visual features (pad/truncate view list to n_views, as in forward)
    imgs_batch = [ex["image"] for ex in examples]
    imgs_batch = [views[:self.n_views] + [views[-1]] * max(0, self.n_views - len(views)) for views in imgs_batch]
    img_feats  = self._extract_visual_features(imgs_batch)

    # state
    state_raw  = self._get_state(examples)
    state_feat = self.state_proj(state_raw)

    # z = 0 at inference (mean of prior)
    z_raw = torch.zeros(B, self.latent_dim, device=device, dtype=dtype)
    z     = self.latent_proj(z_raw)

    # encoder
    ctx    = self._build_encoder_input(img_feats, state_feat, z)
    memory = self.transformer_encoder(ctx)

    # decoder (mirrors the training-time query construction in forward)
    n_ctx = ctx.shape[1]
    q_idx = torch.arange(self.chunk_len, device=device)
    queries = self.query_embed(q_idx).unsqueeze(0).expand(B, -1, -1)
    q_pos_idx = torch.arange(n_ctx, n_ctx + self.chunk_len, device=device)
    queries = queries + self.pos_enc(q_pos_idx).unsqueeze(0)
    decoded     = self.transformer_decoder(queries, memory)
    pred_actions = self.action_head(decoded)                           # (B, chunk_len, D)

    return {"normalized_actions": pred_actions.cpu().float().numpy()}

CosmosPolicy

CosmosPolicy

CosmosPolicy Framework

A video diffusion model (Cosmos Predict2 2B DiT) fine-tuned for robot policy prediction. Unlike VLM-based frameworks (QwenOFT, QwenGR00T), this uses latent-space diffusion: - WAN 2.1 VAE encodes images to latent space (frozen) - MiniTrainDIT backbone denoises latent sequences (trainable) - Actions/proprio/value are injected into latent frames - T5 text embeddings provide language conditioning (precomputed)

Latent frame layout (LIBERO, state_t=9): [blank, curr_proprio, curr_wrist, curr_primary, action, future_proprio, future_wrist, future_primary, value]

CosmosPolicy

CosmosPolicy(config=None, **kwargs)

Bases: BaseFramework

Cosmos-Policy: latent-space video diffusion for robot action prediction.

Training: VAE encode → inject action/proprio → diffusion loss on latent sequence Inference: multi-step denoising → extract action from latent frame

Source code in AlphaBrain/model/framework/CosmosPolicy.py
def __init__(self, config=None, **kwargs):
    """
    Build the Cosmos-Policy framework.

    Assembles: frozen WAN VAE tokenizer, trainable DiT backbone (bf16),
    EDM noise schedule, inference sampler, and lazily-loaded T5 text
    embeddings. Pretrained DiT weights are loaded when a checkpoint path
    exists.

    Args:
        config: config object; hyper-parameters are read from
            ``config.framework.cosmos_policy``.
        **kwargs: unused here; accepted for factory-signature compatibility.
    """
    super().__init__()
    self.config = config
    cp_cfg = config.framework.cosmos_policy

    # --- Action / sequence config ---
    self.action_dim = cp_cfg.action_dim          # 7 (LIBERO)
    self.chunk_size = cp_cfg.chunk_size           # 16 (LIBERO)
    self.proprio_dim = cp_cfg.proprio_dim         # 9 (LIBERO)
    self.state_t = cp_cfg.state_t                 # 9 (LIBERO)

    # Latent frame indices (layout documented in the module docstring)
    self.blank_idx = 0
    self.curr_proprio_idx = 1
    self.curr_wrist_idx = 2
    self.curr_primary_idx = 3
    self.action_idx = 4
    self.future_proprio_idx = 5
    self.future_wrist_idx = 6
    self.future_primary_idx = 7
    self.value_idx = 8

    # Frames 0-3 are conditioning (kept clean); 4-8 are denoised targets.
    self.condition_frame_indices = [0, 1, 2, 3]
    self.prediction_frame_indices = [4, 5, 6, 7, 8]

    # Loss config
    loss_cfg = getattr(cp_cfg, 'loss', None)
    self.action_loss_multiplier = getattr(loss_cfg, 'action_loss_multiplier', 1.0) if loss_cfg else 1.0
    self.world_model_loss_weight = getattr(loss_cfg, 'world_model_loss_weight', 1.0) if loss_cfg else 1.0
    self.value_loss_weight = getattr(loss_cfg, 'value_loss_weight', 0.0) if loss_cfg else 0.0
    self.loss_scale = getattr(loss_cfg, 'loss_scale', 10.0) if loss_cfg else 10.0
    self.sigma_data = getattr(loss_cfg, 'sigma_data', 1.0) if loss_cfg else 1.0
    self.adjust_video_noise = getattr(loss_cfg, 'adjust_video_noise', True) if loss_cfg else True

    # --- 1. VAE Tokenizer (frozen) ---
    # Force deterministic cuDNN for cross-process reproducibility
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False

    # Use VLA-Engine own WanVAE (loads same .pth weights, no cosmos_policy dependency).
    from AlphaBrain.model.modules.world_model.cosmos.wan_vae import WanVAEWrapper
    pretrained_dir = cp_cfg.checkpoint.pretrained_dir
    if pretrained_dir is None:
        pretrained_dir = os.path.join(
            os.environ.get('PRETRAINED_MODELS_DIR', 'data/pretrained_models'),
            'Cosmos-Predict2-2B-Video2World'
        )
    self.vae = WanVAEWrapper(
        pretrained_dir=pretrained_dir,
        dtype=torch.bfloat16,
        device="cpu",
        temporal_window=16,
    )

    # --- 2. DIT Backbone (trainable) ---
    # Use VLA own MinimalV1LVGDiT with TE backend (no cosmos_policy dependency).
    # It inherits MiniTrainDIT and adds condition_video_input_mask concatenation
    # in forward(), plus timestep_scale support.
    from AlphaBrain.model.modules.world_model.cosmos.official_dit import MinimalV1LVGDiT
    dit_cfg = cp_cfg.dit
    self.dit = MinimalV1LVGDiT(
        max_img_h=getattr(dit_cfg, 'max_img_h', 240),
        max_img_w=getattr(dit_cfg, 'max_img_w', 240),
        max_frames=getattr(dit_cfg, 'max_frames', 128),
        in_channels=16,   # state_ch=16; MinimalV1LVGDiT adds +1 for condition mask
        out_channels=16,
        patch_spatial=getattr(dit_cfg, 'patch_spatial', 2),
        patch_temporal=getattr(dit_cfg, 'patch_temporal', 1),
        concat_padding_mask=True,
        model_channels=getattr(dit_cfg, 'model_channels', 2048),
        num_blocks=getattr(dit_cfg, 'num_blocks', 28),
        num_heads=getattr(dit_cfg, 'num_heads', 16),
        crossattn_emb_channels=getattr(dit_cfg, 'crossattn_emb_channels', 1024),
        use_crossattn_projection=getattr(dit_cfg, 'use_crossattn_projection', False),
        atten_backend="minimal_a2a",  # Official attention backend (local copy, no cosmos_policy dep)
        pos_emb_cls=getattr(dit_cfg, 'pos_emb_cls', "rope3d"),
        mlp_ratio=getattr(dit_cfg, 'mlp_ratio', 4.0),
        use_adaln_lora=True,
        adaln_lora_dim=256,
        use_wan_fp32_strategy=getattr(dit_cfg, 'use_wan_fp32_strategy', False),
        pos_emb_learnable=getattr(dit_cfg, 'pos_emb_learnable', True),
        rope_enable_fps_modulation=getattr(dit_cfg, 'rope_enable_fps_modulation', False),
        rope_h_extrapolation_ratio=getattr(dit_cfg, 'rope_h_extrapolation_ratio', 3.0),
        rope_w_extrapolation_ratio=getattr(dit_cfg, 'rope_w_extrapolation_ratio', 3.0),
    )

    # --- 3. SDE (noise schedule) ---
    from AlphaBrain.model.modules.world_model.cosmos.hybrid_edm_sde import HybridEDMSDE
    sde_cfg = cp_cfg.sde
    self.sde = HybridEDMSDE(
        p_mean=getattr(sde_cfg, 'p_mean', 1.3862943611198906),
        p_std=getattr(sde_cfg, 'p_std', 1.2),
        sigma_max=getattr(sde_cfg, 'sigma_max', 200.0),
        sigma_min=getattr(sde_cfg, 'sigma_min', 0.01),
        hybrid_sigma_distribution=True,
    )

    # --- 4. Sampler (inference) ---
    from AlphaBrain.model.modules.world_model.cosmos.cosmos_sampler import CosmosPolicySampler
    self.sampler = CosmosPolicySampler()

    # Inference config
    inf_cfg = getattr(cp_cfg, 'inference', None)
    self.action_num_steps = getattr(inf_cfg, 'action_num_steps', 5) if inf_cfg else 5
    self.inference_sigma_min = getattr(inf_cfg, 'sigma_min', 4.0) if inf_cfg else 4.0
    self.inference_sigma_max = getattr(inf_cfg, 'sigma_max', 80.0) if inf_cfg else 80.0

    # --- 5. T5 embeddings (precomputed) ---
    self.t5_embeddings = None  # loaded lazily
    self.t5_embeddings_path = getattr(cp_cfg, 't5_embeddings_path', None)

    # --- 6. Conditioning ---
    self.sigma_conditional = getattr(cp_cfg, 'sigma_conditional', 0.0)

    # Load pretrained DIT weights
    ckpt_path = getattr(cp_cfg.checkpoint, 'load_path', None)
    if ckpt_path is None:
        ckpt_path = os.path.join(pretrained_dir, 'model-480p-16fps.pt')
    if ckpt_path and os.path.exists(ckpt_path):
        self._load_dit_checkpoint(ckpt_path)

    # Convert DIT to bf16, matching original's on_train_start():
    #   self.net = self.net.to(**self.tensor_kwargs)
    # where tensor_kwargs = {"device": "cuda", "dtype": bfloat16}.
    # The original always runs DiT in bf16 for both training and inference.
    # We convert dtype here (device placement handled by trainer or server).
    self.dit = self.dit.to(dtype=torch.bfloat16)

    logger.info(f"[CosmosPolicy] Initialized: state_t={self.state_t}, "
                 f"action_dim={self.action_dim}, chunk_size={self.chunk_size}")
forward
forward(examples, **kwargs)

Training forward pass: diffusion denoising loss on full 9-frame latent sequence.

Parameters:

Name Type Description Default
examples

batched dict from DataLoader (each value is a stacked tensor), or list of dicts (legacy).

required

Returns:

Type Description

{"action_loss": total_loss}

Source code in AlphaBrain/model/framework/CosmosPolicy.py
def forward(self, examples, **kwargs):
    """
    Training forward pass: diffusion denoising loss on full 9-frame latent sequence.

    Expected batch keys (per the reads below): "video" (uint8 frames),
    "actions", "proprio", "future_proprio", "t5_text_embeddings"; optional
    "value_function_return" and the three *_mask flags (default to zeros).

    Args:
        examples: batched dict from DataLoader (each value is a stacked tensor),
                  or list of dicts (legacy).

    Returns:
        {"action_loss": total_loss}  plus detached "action_mse" / "cond_mse"
        for logging.
    """
    from AlphaBrain.model.modules.world_model.cosmos.latent_utils import (
        replace_latent_with_action_chunk,
        replace_latent_with_proprio,
    )

    device = next(self.dit.parameters()).device

    # Handle both batched dict (from DataLoader) and list of dicts (legacy)
    if isinstance(examples, dict):
        batch = examples
    else:
        batch = {k: torch.stack([ex[k] for ex in examples]) for k in examples[0]}

    B = batch["video"].shape[0]

    # --- Step 1: Gather batch data ---
    videos = batch["video"].to(device)                          # (B, 3, 33, 224, 224) uint8
    actions = batch["actions"].to(device)                       # (B, chunk_size, action_dim)
    proprios = batch["proprio"].to(device)                      # (B, proprio_dim)
    future_proprios = batch["future_proprio"].to(device)
    t5_embs = batch["t5_text_embeddings"].to(device)            # (B, 512, 1024)
    values = batch.get("value_function_return", torch.zeros(B)).to(device=device, dtype=torch.float32)

    # Masks for loss computation
    rollout_masks = batch.get("rollout_data_mask", torch.zeros(B)).to(device=device, dtype=torch.float32)
    wm_masks = batch.get("world_model_sample_mask", torch.zeros(B)).to(device=device, dtype=torch.float32)
    vf_masks = batch.get("value_function_sample_mask", torch.zeros(B)).to(device=device, dtype=torch.float32)

    # --- Step 2: VAE encode → latent ---
    with torch.no_grad():
        # Normalize uint8 [0, 255] → float [-1, 1]
        # IMPORTANT: Normalize in bf16 to match original cosmos-policy's
        # _normalize_video_databatch_inplace: video.to(dtype=bf16) / 127.5 - 1.0
        video_norm = videos.to(dtype=torch.bfloat16) / 127.5 - 1.0
        # Match official pipeline memory layout (channels_last_3d)
        video_norm = video_norm.to(memory_format=torch.channels_last_3d)
        x0 = self.vae.encode(video_norm)  # (B, 16, 9, 28, 28)

    # Apply sigma_data scaling (matches official: encode(x) * sigma_data)
    # The VAE already does per-channel mean/std normalization internally;
    # this additional scaling aligns the latent magnitude with EDM preconditioning.
    x0 = x0 * self.sigma_data
    x0 = x0.to(dtype=torch.bfloat16)

    # --- Step 3: Inject action/proprio/value into latent frames ---
    action_indices = torch.full((B,), self.action_idx, device=device, dtype=torch.long)
    curr_proprio_indices = torch.full((B,), self.curr_proprio_idx, device=device, dtype=torch.long)
    future_proprio_indices = torch.full((B,), self.future_proprio_idx, device=device, dtype=torch.long)

    x0 = replace_latent_with_action_chunk(
        x0, actions.to(x0.dtype), action_indices
    )
    x0 = replace_latent_with_proprio(
        x0, proprios.to(x0.dtype), curr_proprio_indices
    )
    x0 = replace_latent_with_proprio(
        x0, future_proprios.to(x0.dtype), future_proprio_indices
    )

    # Value: expand scalar to fill latent volume at value_idx
    value_flat = values.to(x0.dtype).view(B, 1, 1, 1).expand(B, 16, 28, 28)
    x0[:, :, self.value_idx, :, :] = value_flat

    # --- Step 4: Save clean condition frames for denoise ---
    # Official pattern: noise ALL frames, then replace condition frames inside denoise
    n_cond = len(self.condition_frame_indices)
    gt_frames = x0[:, :, :n_cond, :, :].clone()

    # --- Step 5: Sample sigma and noise ---
    sigma = self.sde.sample_t(B, device=device).to(x0.dtype)

    # Video noise multiplier: sigma *= sqrt(state_t) when adjust_video_noise=True
    # Matches original text2world_model.py behavior (state_t=9 → sigma *= 3.0)
    if self.adjust_video_noise:
        sigma = sigma * (self.state_t ** 0.5)

    epsilon = torch.randn_like(x0)

    # Apply noise to ALL frames (official pattern: marginal_prob returns x0, sigma)
    # xt = x0 + sigma * epsilon
    sigma_5d = sigma.view(B, 1, 1, 1, 1)
    xt = x0 + sigma_5d * epsilon

    # --- Step 6: Denoise with EDM preconditioning ---
    crossattn_emb = t5_embs.to(dtype=torch.bfloat16)
    # padding_mask in pixel space (224x224), matching original
    padding_mask = torch.zeros(B, 1, 224, 224, device=device, dtype=xt.dtype)
    fps = torch.tensor([16] * B, device=device, dtype=torch.float32)  # Must match original (always 16)

    # EDM preconditioning with official condition frame handling:
    # - condition frames replaced with gt / sigma_data in network input
    # - condition frames use sigma_conditional c_noise
    # - output condition frames replaced with clean gt
    n_cond = len(self.condition_frame_indices)
    x0_pred = self._denoise(
        xt, sigma, crossattn_emb, fps, padding_mask,
        gt_frames=gt_frames, n_cond_frames=n_cond
    )

    # --- Step 7: Compute loss ---
    # RectifiedFlow loss weight: (1 + sigma)^2 / sigma^2
    loss_weight = (1 + sigma) ** 2 / sigma ** 2
    loss_weight = loss_weight.view(B, 1, 1, 1, 1)

    # MSE between prediction and ground truth
    pred_mse = (x0_pred - x0) ** 2  # (B, C, T, H, W)

    # Official LIBERO: no per-sample loss masking (all mask flags default False).
    # All 9 latent frames contribute equally to loss.
    # Condition frames contribute ~0 since x0_pred == gt (replaced in denoise).
    edm_loss = pred_mse * loss_weight
    total_loss = edm_loss.mean() * self.loss_scale

    # Log per-component losses for monitoring (no gradient)
    with torch.no_grad():
        action_mse = pred_mse[:, :, self.action_idx, :, :].mean()
        cond_mse = pred_mse[:, :, :len(self.condition_frame_indices), :, :].mean()

    return {"action_loss": total_loss, "action_mse": action_mse, "cond_mse": cond_mse}
predict_action
predict_action(examples=None, batch_images=None, instructions=None, **kwargs)

Inference: multi-step denoising to predict action chunk.

Matches original cosmos-policy get_action() flow: 1. Build full 33-frame video (with placeholders for prediction frames) 2. VAE encode → 9 latent frames 3. Inject normalized proprio into frame 1 4. Save condition frames (0-3), replace prediction frames (4-8) with noise 5. Multi-step denoising 6. Extract action from denoised latent at action_idx

Parameters:

Name Type Description Default
examples

list of dicts with image, wrist_image, lang, proprio

None
batch_images

alternative — list of PIL images

None
instructions

alternative — list of strings

None

Returns:

Type Description

{"normalized_actions": np.ndarray of shape (B, chunk_size, action_dim)}

Source code in AlphaBrain/model/framework/CosmosPolicy.py
@torch.inference_mode()
def predict_action(self, examples=None, batch_images=None,
                   instructions=None, **kwargs):
    """
    Inference: multi-step denoising to predict action chunk.

    Matches original cosmos-policy get_action() flow:
    1. Build full 33-frame video (with placeholders for prediction frames)
    2. VAE encode → 9 latent frames
    3. Inject normalized proprio into frame 1
    4. Save condition frames (0-3), replace prediction frames (4-8) with noise
    5. Multi-step denoising
    6. Extract action from denoised latent at action_idx

    Args:
        examples: list of dicts with image, wrist_image, lang, proprio
        batch_images: alternative — list of PIL images
        instructions: alternative — list of strings

    Returns:
        Dict with keys:
            "normalized_actions": np.ndarray of shape (B, chunk_size, action_dim)
            "denoised_latent": final denoised latent tensor from the sampler
            "x0_latent": clean encoded latent (proprio injected), same shape
    """
    from AlphaBrain.model.modules.world_model.cosmos.latent_utils import (
        replace_latent_with_proprio,
    )

    device = next(self.dit.parameters()).device

    # Parse inputs
    if examples is not None:
        primary_images = [ex["image"] for ex in examples]
        wrist_images = [ex.get("wrist_image") for ex in examples]
        instructions = [ex["lang"] for ex in examples]
        proprios = [ex["proprio"] for ex in examples]
    else:
        # batch_images path: wrist images / proprios default to None / zeros
        primary_images = batch_images
        wrist_images = kwargs.get("wrist_images", [None] * len(primary_images))
        proprios = kwargs.get("proprios", [np.zeros(self.proprio_dim)] * len(primary_images))

    B = len(primary_images)

    # Preprocess images to match official cosmos-policy inference pipeline:
    # JPEG compression (quality=95) + 90% center crop + resize to 224x224
    primary_images = [self._preprocess_image(img) for img in primary_images]
    wrist_images = [self._preprocess_image(img) if img is not None else None for img in wrist_images]

    # --- Step 1: Build full 33-frame video (matching original get_action) ---
    full_video = self._build_full_video(primary_images, wrist_images)  # (B, 3, 33, 224, 224)
    full_video = full_video.to(device)

    # --- Step 2: VAE encode full video → 9 latent frames ---
    # IMPORTANT: Normalize in bf16 to match original cosmos-policy's
    # _normalize_video_databatch_inplace: video.to(dtype=bf16) / 127.5 - 1.0
    # Normalizing in float32 then converting to bf16 gives different rounding.
    video_norm = full_video.to(dtype=torch.bfloat16) / 127.5 - 1.0
    # Match official pipeline memory layout (channels_last_3d)
    video_norm = video_norm.to(memory_format=torch.channels_last_3d)

    x0 = self.vae.encode(video_norm)  # (B, 16, 9, 28, 28)
    x0 = x0 * self.sigma_data  # Apply sigma_data scaling
    x0 = x0.to(torch.bfloat16)

    # --- Step 3: Inject normalized proprio into latent frame 1 ---
    proprio_array = np.array(proprios)
    # Normalize proprio to [-1, 1] using dataset stats (matches original)
    proprio_array = self._normalize_proprio(proprio_array)
    proprio_tensor = torch.tensor(proprio_array, device=device, dtype=x0.dtype)
    x0 = replace_latent_with_proprio(
        x0, proprio_tensor,
        torch.full((B,), self.curr_proprio_idx, device=device, dtype=torch.long)
    )

    # --- Step 4: Save condition frames, generate full-frame noise ---
    n_cond = len(self.condition_frame_indices)
    gt_frames = x0[:, :, :n_cond, :, :].clone()  # (B, 16, 4, 28, 28)

    B_lat, C, T_full, H_lat, W_lat = x0.shape

    # Generate noise for ALL frames (matching original: arch_invariant_rand over full state_shape)
    # The sampler receives x_sigma_max with noise on all frames;
    # _denoise() replaces condition frames with gt internally at each step.
    # IMPORTANT: Keep x_sigma_max in float32 (NOT bf16) to match original cosmos-policy.
    # The sampler uses x_sigma_max.dtype as in_dtype for sigma precision, and
    # float32 sigma is needed for accurate EDM preconditioning (c_in, c_noise).
    from AlphaBrain.model.modules.world_model.cosmos.noise_utils import arch_invariant_rand
    x_sigma_max = arch_invariant_rand(
        (B_lat, C, T_full, H_lat, W_lat), seed=1
    ).to(device=device, dtype=torch.float32) * self.inference_sigma_max

    # Get T5 embeddings
    crossattn_emb = self._get_t5_embeddings_for_inference(instructions, device)

    # --- Step 5: Multi-step denoising ---
    # padding_mask in pixel space (224x224), matching original cosmos_utils.py
    padding_mask = torch.zeros(B, 1, 224, 224, device=device, dtype=x0.dtype)
    fps = torch.tensor([16] * B, device=device, dtype=torch.float32)

    def x0_fn(xt, sigma):
        # Denoiser closure handed to the sampler; re-injects gt condition frames
        return self._denoise(
            xt, sigma, crossattn_emb, fps, padding_mask,
            gt_frames=gt_frames, n_cond_frames=n_cond
        )

    denoised = self.sampler.forward(
        x0_fn=x0_fn,
        x_sigma_max=x_sigma_max,  # full 9-frame noise (matching original)
        num_steps=self.action_num_steps,
        sigma_min=self.inference_sigma_min,
        sigma_max=self.inference_sigma_max,
    )
    # --- Step 6: Extract action from denoised latent ---
    action_latent = denoised[:, :, self.action_idx, :, :]  # (B, 16, 28, 28)
    action_chunk = self._extract_action_from_latent(action_latent)  # (B, chunk_size, action_dim)

    return {
        "normalized_actions": action_chunk.float().cpu().numpy(),
        "denoised_latent": denoised,
        "x0_latent": x0,
    }
set_dataset_stats
set_dataset_stats(dataset_stats: dict)

Store dataset statistics for proprio normalization during inference.

Parameters:

Name Type Description Default
dataset_stats dict

dict with keys 'proprio_min', 'proprio_max' (np.ndarray).

required
Source code in AlphaBrain/model/framework/CosmosPolicy.py
def set_dataset_stats(self, dataset_stats: dict):
    """
    Register dataset statistics used to normalize proprio at inference time.

    Args:
        dataset_stats: mapping providing 'proprio_min' and 'proprio_max'
            entries (each an np.ndarray).
    """
    self._dataset_stats = dataset_stats
    logger.info("[CosmosPolicy] Dataset stats set for proprio normalization")

NeuroVLA

NeuroVLA

NeuroVLA

NeuroVLA(config: Optional[dict] = None, norm_stats: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]] = None, **kwargs)

Bases: BaseFramework

NeuroVLA: Vision-Language-Action model for robotic manipulation.

This model combines a vision-language model (Qwen-VL) with action prediction to generate robot actions from visual observations and language instructions.

Source code in AlphaBrain/model/framework/NeuroVLA.py
def __init__(
    self,
    config: Optional[dict] = None,
    norm_stats: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]] = None,
    **kwargs,
) -> None:
    """Assemble the NeuroVLA stack: VLM backbone, layer-wise Q-Former, and action heads."""
    super().__init__()
    self.config = config

    # Qwen-VL backbone: encodes images and language instructions.
    self.qwen_vl_interface = get_vlm_model(config=self.config)

    # Layer-wise Q-Former: distills VLM hidden states into action-relevant queries.
    self.layer_qformer = get_layerwise_qformer(config=self.config)

    # Action regression head operating on 768-d features (hidden 1536, 7-DoF output).
    feature_dim = 768
    self.action_model = get_action_model(input_dim=feature_dim, hidden_dim=feature_dim * 2, action_dim=7)

    # GRU-based editor: refines action features conditioned on the 8-d robot state.
    self.edit_model = get_gruedit_model(input_dim=feature_dim, hidden_dim=256, robot_state_dim=8)

    self.l1_loss = nn.L1Loss()
    self.norm_stats = norm_stats
forward
forward(examples: List[dict] = None, repeated_diffusion_steps: int = 4, **kwargs) -> Tuple

Run a forward pass through the VLM, returning loss for training.

Parameters:

Name Type Description Default
examples List[dict]

List of training examples, each containing: - "image": Input images - "lang": Language instructions - "action": Ground truth actions [B, T, 7] - "state": Robot states [B, T, 8] - "solution" (optional): Chain-of-thought solutions

None

Returns:

Type Description
dict

Dictionary containing action_loss

Source code in AlphaBrain/model/framework/NeuroVLA.py
def forward(
    self,
    examples: List[dict] = None,
    repeated_diffusion_steps: int = 4,
    **kwargs,
) -> dict:
    """
    Run a forward pass through the VLM, returning loss for training.

    Args:
        examples: List of training examples, each containing:
            - "image": Input images
            - "lang": Language instructions
            - "action": Ground truth actions [B, T, 7]
            - "state": Robot states [B, T, 8]
            - "solution" (optional): Chain-of-thought solutions
        repeated_diffusion_steps: Unused; kept for interface compatibility.

    Returns:
        Dict with key "action_loss" (scalar tensor).
    """
    # Extract data from examples
    images = [example["image"] for example in examples]
    instructions = [example["lang"] for example in examples]
    actions = [example["action"] for example in examples]
    assert "state" in examples[0], (
        "NeuroVLA requires 'state' in training samples. "
        "Please set 'include_state: true' in your dataset config yaml (datasets.vla_data.include_state)."
    )
    states = [example["state"] for example in examples]

    solutions = [example["solution"] for example in examples] if "solution" in examples[0] else None

    # Build inputs for vision-language model
    qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(
        images=images, instructions=instructions, solutions=solutions
    )

    # Forward pass through VLM to get hidden states
    with torch.autocast("cuda", dtype=torch.bfloat16):
        qwenvl_outputs = self.qwen_vl_interface(
            **qwen_inputs,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )

    # Guard against missing/NaN CoT loss (e.g. when no text labels are given).
    # NOTE(review): vlm_cot_loss is computed but not folded into the returned
    # loss — the return dict only carries action_loss. Confirm this is intended.
    vlm_cot_loss = qwenvl_outputs.loss
    if vlm_cot_loss is None or torch.isnan(vlm_cot_loss):
        vlm_cot_loss = torch.tensor(0.0, device=self.qwen_vl_interface.model.device)

    # Action prediction with iterative refinement
    with torch.autocast("cuda", dtype=torch.float32):
        # Extract action-relevant features from VLM hidden states
        start_layer = self.config.framework.layer_qformer.qformer_start_layer if self.config else -6
        end_layer = self.config.framework.layer_qformer.qformer_end_layer if self.config else -1
        action_latent_feature = self.layer_qformer(qwenvl_outputs.hidden_states[start_layer:end_layer])

        states = torch.tensor(np.array(states), dtype=torch.float32, device=action_latent_feature.device)
        all_predicted_actions = []

        # Number of refinement iterations: total action horizon split into
        # chunks of num_query_tokens steps (at least one iteration).
        action_horizon = np.array(actions).shape[1]  # total action steps from ground truth
        chunk_size = self.layer_qformer.num_query_tokens  # steps predicted per iteration
        num_iterations = max(1, action_horizon // chunk_size)

        # Iterative action prediction (single initialization of the counter)
        inference_num = 0
        while inference_num < num_iterations:
            # Edit action features based on current robot states
            edit_action_feature = self.edit_model(action_latent_feature, states)

            # Predict action chunk
            predicted_actions = self.action_model.predict_action(edit_action_feature)
            all_predicted_actions.append(predicted_actions)

            # Roll states forward: predicted pose fills dims 0-6, gripper (dim 7) kept.
            predicted_states = torch.zeros_like(states)
            predicted_states[:, :predicted_actions.shape[1], :7] = predicted_actions
            predicted_states[:, :, 7] = states[:, :, 7]  # Keep gripper state
            states = predicted_states.clone()
            inference_num += 1

        # Compute action loss: L1 between concatenated chunks and full ground truth
        action_tensor = torch.tensor(np.array(actions), dtype=torch.float32, device=predicted_actions.device)
        predicted_action_tensor = torch.cat(all_predicted_actions, dim=1)
        action_loss = self.l1_loss(predicted_action_tensor, action_tensor)

    return {"action_loss": action_loss}
predict_action
predict_action(batch_images: Union[Image, List[Image]], instructions: List[str], states: Optional[List[Sequence[float]]] = None, solutions: Union[Dict, List[Dict]] = None, unnorm_key: Optional[str] = None, cfg_scale: float = 1.5, use_ddim: bool = False, num_ddim_steps: int = 5, **kwargs: str) -> np.ndarray

Predict action from images and instructions.

Parameters:

Name Type Description Default
batch_images Union[Image, List[Image]]

Input images (PIL Image or list of PIL Images)

required
instructions List[str]

Task instructions (list of strings)

required
states Optional[List[Sequence[float]]]

Robot states history [B, T, 8], where last dim is [x,y,z,roll,pitch,yaw,gripper,pad]

None
solutions Union[Dict, List[Dict]]

Optional solution dict for chain-of-thought

None
unnorm_key Optional[str]

Key for unnormalization (if using norm_stats)

None
cfg_scale float

Classifier-free guidance scale (>1.0 enables CFG)

1.5
use_ddim bool

Whether to use DDIM sampling

False
num_ddim_steps int

Number of DDIM steps

5

Returns:

Type Description
dict

Dictionary containing "normalized_actions" [B, T, 7]

Source code in AlphaBrain/model/framework/NeuroVLA.py
@torch.inference_mode()
def predict_action(
    self,
    batch_images: Union[Image, List[Image]],
    instructions: List[str],
    states: Optional[List[Sequence[float]]] = None,
    solutions: Union[Dict, List[Dict]] = None,
    unnorm_key: Optional[str] = None,
    cfg_scale: float = 1.5,
    use_ddim: bool = False,
    num_ddim_steps: int = 5,
    **kwargs: str
) -> dict:
    """
    Predict action from images and instructions.

    Args:
        batch_images: Input images (PIL Image or list of PIL Images)
        instructions: Task instructions (list of strings)
        states: Robot states history [B, T, 8], where last dim is
            [x,y,z,roll,pitch,yaw,gripper,pad]. Must be provided; no default
            state is synthesized.
        solutions: Optional solution dict for chain-of-thought (unused here)
        unnorm_key: Key for unnormalization (unused; actions returned normalized)
        cfg_scale: Reserved for classifier-free guidance (currently unused)
        use_ddim: Reserved for DDIM sampling (currently unused)
        num_ddim_steps: Reserved number of DDIM steps (currently unused)

    Returns:
        Dict with key "normalized_actions": np.ndarray [B, T, 7]
    """
    predict_num = 0

    # Convert client-side numpy arrays to PIL images.
    # TODO: consider moving this conversion out of the model.
    if isinstance(batch_images[0][0], np.ndarray):
        batch_images = [[Image.fromarray(img) for img in seq] for seq in batch_images]

    batch_images = resize_images(batch_images, target_size=(224, 224))

    # Build VLM inputs
    qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions)

    all_predicted_actions = []

    # Generate cognition features through VLM.
    # NOTE: labels are passed so HF also computes an (unused) LM loss; this
    # mirrors the training forward's input signature.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        qwenvl_outputs = self.qwen_vl_interface(
            input_ids=qwen_inputs.input_ids,
            attention_mask=qwen_inputs.attention_mask,
            pixel_values=qwen_inputs.pixel_values,
            image_grid_thw=qwen_inputs.image_grid_thw,
            labels=qwen_inputs.input_ids.clone(),
            output_hidden_states=True,
            return_dict=True,
        )

    # Action prediction with iterative refinement
    with torch.autocast("cuda", dtype=torch.float32):
        # Extract action features from VLM hidden states
        start_layer = self.config.framework.layer_qformer.qformer_start_layer if self.config else -2
        end_layer = self.config.framework.layer_qformer.qformer_end_layer if self.config else -1

        action_latent_feature = self.layer_qformer(qwenvl_outputs.hidden_states[start_layer:end_layer])

        # Convert states to tensor
        states = torch.tensor(
            np.array(states, dtype=np.float32),
            dtype=torch.float32,
            device=action_latent_feature.device
        )

        # Iterative action prediction.
        # action_horizon equals num_query_tokens, so num_iterations is always 1
        # at inference (matches training where horizon // chunk_size == 1).
        action_horizon = self.layer_qformer.num_query_tokens
        num_iterations = max(1, action_horizon // self.layer_qformer.num_query_tokens)
        while predict_num < num_iterations:
            # Edit action features based on current states
            edit_action_feature = self.edit_model(action_latent_feature, states)

            # Predict action chunk
            samples = self.action_model.predict_action(edit_action_feature)
            all_predicted_actions.append(samples)

            # Roll states forward: predicted pose fills dims 0-6, gripper kept.
            predicted_states = torch.zeros_like(states)
            predicted_states[:, :samples.shape[1], :7] = samples
            predicted_states[:, :, 7] = states[:, :, 7]  # Keep gripper state
            states = predicted_states.clone()
            predict_num += 1

    # Concatenate all predicted action chunks
    predicted_action_tensor = torch.cat(all_predicted_actions, dim=1)
    normalized_actions = predicted_action_tensor.detach().cpu().numpy()
    return {"normalized_actions": normalized_actions}

build_model_framework

build_model_framework(config: dict = {}) -> NeuroVLA

Build NeuroVLA model from config.

Source code in AlphaBrain/model/framework/NeuroVLA.py
def build_model_framework(config: dict = None) -> NeuroVLA:
    """Build NeuroVLA model from config.

    Args:
        config: Config object/dict for NeuroVLA; defaults to an empty dict.

    Returns:
        Instantiated NeuroVLA model.
    """
    # Avoid the mutable-default-argument pitfall (`config: dict = {}`):
    # a fresh dict is created per call instead of sharing one module-level dict.
    if config is None:
        config = {}
    model = NeuroVLA(config=config)
    return model

PaliGemma family

PaliGemmaOFT

PaliGemma-OFT Framework

Uses PaliGemma (SigLIP + Gemma 2B) as VLM backbone with action special token for continuous action prediction via L1 regression. Mirrors LlamaOFT / QwenOFT architecture but with PaliGemma backbone.

PaliGemma_OFT

PaliGemma_OFT(config: Optional[dict] = None, **kwargs)

Bases: BaseFramework

PaliGemma + action token OFT framework. Predicts continuous actions via L1 regression on action token hidden states.

Source code in AlphaBrain/model/framework/PaliGemmaOFT.py
def __init__(self, config: Optional[dict] = None, **kwargs) -> None:
    """PaliGemma backbone + action-token OFT head: continuous actions via L1 regression."""
    super().__init__()
    self.config = config

    # Dedicated PaliGemma-OFT interface (get_vlm_model would route to the Pi0 variant).
    self.paligemma_vl_interface = _PaliGemma_OFT_VL_Interface(config=self.config)

    # The action head consumes LLM hidden states, so its hidden dim must match
    # the Gemma text hidden size (2048 for Gemma 2B).
    llm_hidden = self.paligemma_vl_interface.model.config.text_config.hidden_size
    config.framework.action_model.action_hidden_dim = llm_hidden
    self.action_model = get_action_model(config=self.config)

    # Trade compute for memory during training when the backbone supports it.
    if hasattr(self.paligemma_vl_interface.model, 'gradient_checkpointing_enable'):
        self.paligemma_vl_interface.model.gradient_checkpointing_enable()
        logger.info("Enabled gradient checkpointing for PaliGemma model")

    # Action chunk length = past window + current step + future window.
    self.future_action_window_size = config.framework.action_model.future_action_window_size
    self.past_action_window_size = config.framework.action_model.past_action_window_size
    self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size

    # Special token whose hidden states are decoded into continuous actions.
    self.action_token = "🔍"
    tokenized = self.paligemma_vl_interface.processor.tokenizer(
        "🔍", add_special_tokens=False
    )
    self.action_token_id = tokenized["input_ids"][0]

    self.l1_loss = nn.L1Loss()
forward
forward(examples: List[dict] = None, **kwargs) -> Tuple

Training forward: L1 regression on action tokens.

Source code in AlphaBrain/model/framework/PaliGemmaOFT.py
def forward(self, examples: List[dict] = None, **kwargs) -> Tuple:
    """Training forward: regress continuous actions (L1) from action-token hidden states."""
    images = [ex["image"] for ex in examples]
    raw_instructions = [ex["lang"] for ex in examples]
    gt_actions = [ex["action"] for ex in examples]

    # Route each instruction through the structured multi-task prompt template.
    prompts = [self._build_structured_prompt(inst) for inst in raw_instructions]

    # Build PaliGemma inputs
    model_inputs = self.paligemma_vl_interface.build_paligemma_inputs(
        images=images, instructions=prompts
    )

    # Move every tensor-like input onto the model's device.
    device = next(self.parameters()).device
    model_inputs = {key: (val.to(device) if hasattr(val, 'to') else val) for key, val in model_inputs.items()}

    vlm_out = self.paligemma_vl_interface(
        **model_inputs,
        output_attentions=False,
        output_hidden_states=True,
        return_dict=True,
    )
    final_hidden = vlm_out.hidden_states[-1]

    with torch.autocast("cuda", dtype=torch.float32):
        token_ids = model_inputs.get("input_ids", None)
        queries = self._gather_action_token_embeddings(
            final_hidden, token_ids, action_token_id=self.action_token_id
        )
        predictions = self.action_model.predict_action(queries)

        target = torch.tensor(
            np.array(gt_actions), device=predictions.device, dtype=predictions.dtype
        )
        # Supervise only the current step plus the future window.
        target = target[:, -(self.future_action_window_size + 1):, :]

        action_loss = self.l1_loss(predictions, target)

    return {"action_loss": action_loss}
predict_action
predict_action(batch_images: List = None, instructions: List[str] = None, examples: List[dict] = None, **kwargs) -> np.ndarray

Inference: predict normalized actions.

Source code in AlphaBrain/model/framework/PaliGemmaOFT.py
@torch.inference_mode()
def predict_action(
    self,
    batch_images: List = None,
    instructions: List[str] = None,
    examples: List[dict] = None,
    **kwargs,
) -> np.ndarray:
    """Inference: predict normalized actions."""
    # Inputs may arrive either as `examples` dicts or as parallel lists.
    if examples is not None:
        batch_images = [to_pil_preserve(ex["image"]) for ex in examples]
        instructions = [ex["lang"] for ex in examples]
    else:
        batch_images = [to_pil_preserve(item) for item in batch_images]

    # Resize to the observation size used at training time, when configured.
    obs_size = getattr(self.config.datasets.vla_data, "image_size", None)
    if obs_size:
        batch_images = resize_images(batch_images, target_size=obs_size)

    prompts = [self._build_structured_prompt(text) for text in instructions]

    model_inputs = self.paligemma_vl_interface.build_paligemma_inputs(
        images=batch_images, instructions=prompts
    )

    device = next(self.parameters()).device
    model_inputs = {key: (val.to(device) if hasattr(val, 'to') else val) for key, val in model_inputs.items()}

    vlm_out = self.paligemma_vl_interface(
        **model_inputs,
        output_attentions=False,
        output_hidden_states=True,
        return_dict=True,
    )
    final_hidden = vlm_out.hidden_states[-1]

    with torch.autocast("cuda", dtype=torch.float32):
        token_ids = model_inputs.get("input_ids", None)
        queries = self._gather_action_token_embeddings(
            final_hidden, token_ids, action_token_id=self.action_token_id
        )
        predictions = self.action_model.predict_action(queries)

    normalized_actions = predictions.detach().cpu().numpy()
    return {"normalized_actions": normalized_actions}

PaliGemmaPi0

PaliGemmaOFT Framework

Integrates the π₀/π₀.₅ flow matching architecture into VLA-Engine. Key innovation: the VLM backbone is swappable — you can use PaliGemma (original), Qwen2.5-VL, Llama 3.2 Vision, or any future VLM backend.

Architecture

VLM (any) → prefix embedding → [KV cache] → Action Expert (Gemma) + Flow Matching → actions

Components
  • VLM interface: reuses AlphaBrain's existing get_vlm_model() factory
  • Action Expert: independent Gemma transformer (from openpi)
  • Flow Matching Head: multi-step denoising action generation

Training: flow matching loss (MSE between predicted and target velocity fields) Inference: iterative denoising from Gaussian noise (default 10 steps)

PaliGemma_OFT

PaliGemma_OFT(config: Optional[dict] = None, **kwargs)

Bases: BaseFramework

Pi0/Pi0.5 framework with swappable VLM backbone.

Config structure

framework: name: PaliGemmaOFT pi05: true # true for π₀.₅, false for π₀ paligemma: # or qwenvl/llamavl — uses get_vlm_model() base_vlm: google/paligemma-3b-pt-224 action_expert: width: 1024 depth: 18 ... action_model: action_dim: 7 action_horizon: 50 num_inference_steps: 10

Source code in AlphaBrain/model/framework/PaliGemmaPi0.py
def __init__(self, config: Optional[dict] = None, **kwargs) -> None:
    """
    Build the Pi0/Pi0.5 framework: a swappable VLM backbone, a flow-matching
    action head with a Gemma/Llama action expert, an optional prefix
    projection, selective VLM freezing, and optional MEAN_STD action/state
    normalization.

    Args:
        config: Framework config (OmegaConf-like namespace); see the class
            docstring for the expected ``framework.*`` structure.
    """
    super().__init__()
    self.config = config

    pi05 = getattr(config.framework, 'pi05', True)
    self.pi05 = pi05

    # ── VLM backbone (swappable) ──
    # Determine which VLM to use based on config
    vlm_type = self._detect_vlm_type(config)

    if vlm_type == "paligemma":
        from AlphaBrain.model.modules.vlm.paligemma import _PaliGemma_VL_Interface
        self.vlm_interface = _PaliGemma_VL_Interface(config=config)
    else:
        # Use existing VLM factory for Qwen/Llama/Florence etc.
        self.vlm_interface = get_vlm_model(config=config)

    # ── Action Expert + Flow Matching Head ──
    expert_cfg = config.framework.action_expert
    action_cfg = config.framework.action_model

    # Determine expert type based on framework name and config
    expert_type = "gemma"  # default
    if config.framework.name == "LlamaPi0":
        expert_type = "llama"
    elif hasattr(expert_cfg, 'type'):
        expert_type = expert_cfg.type

    self._tokenizer = None
    # Expert hyper-parameter defaults differ by expert type (gemma vs llama).
    self.flow_matching_head = Pi0FlowMatchingHead(
        action_dim=action_cfg.action_dim,
        action_horizon=action_cfg.action_horizon,
        action_expert_width=getattr(expert_cfg, 'width', 1024),
        action_expert_depth=getattr(expert_cfg, 'depth', 18),
        action_expert_mlp_dim=getattr(expert_cfg, 'mlp_dim', 4096),
        action_expert_num_heads=getattr(expert_cfg, 'num_heads', 8 if expert_type == "gemma" else 32),
        action_expert_num_kv_heads=getattr(expert_cfg, 'num_kv_heads', 1 if expert_type == "gemma" else 8),
        action_expert_head_dim=getattr(expert_cfg, 'head_dim', 256 if expert_type == "gemma" else 128),
        pi05=pi05,
        precision=getattr(expert_cfg, 'precision', 'bfloat16'),
        num_inference_steps=getattr(action_cfg, 'num_inference_steps', 10),
        noise_beta_alpha=getattr(action_cfg, 'noise_beta_alpha', 1.5),
        noise_beta_beta=getattr(action_cfg, 'noise_beta_beta', 1.0),
        expert_type=expert_type,
        state_dim=getattr(action_cfg, 'state_dim', None),
    )

    # ── Prefix projection (for VLM hidden_size != action_expert_width) ──
    expert_width = getattr(expert_cfg, 'width', 1024)
    vlm_hidden_size = self._get_vlm_hidden_size()
    if vlm_hidden_size is not None and vlm_hidden_size != expert_width and vlm_type != "paligemma":
        # PaliGemma uses its own encode_prefix which outputs expert_width directly
        # For Qwen/Llama, VLM hidden states need projection to match action expert
        self.prefix_proj = nn.Linear(vlm_hidden_size, expert_width, bias=False)
        logger.info(f"[prefix_proj] Added projection: VLM hidden={vlm_hidden_size} → expert_width={expert_width}")
    else:
        self.prefix_proj = None

    # ── Action dimension settings ──
    self.action_dim = action_cfg.action_dim
    self.action_horizon = action_cfg.action_horizon
    self.future_action_window_size = getattr(action_cfg, 'future_action_window_size', action_cfg.action_horizon)
    self.past_action_window_size = getattr(action_cfg, 'past_action_window_size', 0)
    self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size

    # Store VLM rotary_emb reference for inference consistency
    vlm_lm = self._get_vlm_language_model()
    if vlm_lm is not None and hasattr(vlm_lm, 'rotary_emb'):
        self.flow_matching_head._vlm_rotary_emb = vlm_lm.rotary_emb

    logger.info(
        f"PaliGemmaOFT initialized: pi05={pi05}, vlm={vlm_type}, "
        f"action_dim={self.action_dim}, horizon={self.action_horizon}"
    )

    # Enable gradient checkpointing on action expert only.
    # _shared_forward handles VLM layers via use_gc (compute_layer checkpoint).
    # Enabling HF-level GC on VLM GemmaModel would cause double-checkpoint conflict
    # with DeepSpeed ZeRO: same parameters reduced twice -> AssertionError.
    for m in self.flow_matching_head.modules():
        if hasattr(m, "gradient_checkpointing_enable"):
            m.gradient_checkpointing_enable()
    # Disable HF-level GC on VLM language model to prevent double-backward
    try:
        vlm_lm = self._get_vlm_language_model()
        if hasattr(vlm_lm, "gradient_checkpointing_disable"):
            vlm_lm.gradient_checkpointing_disable()
    except Exception:
        pass

    # Freeze VLM modules based on config (empty string = no freeze = full finetune)
    freeze_modules = getattr(config.trainer, 'freeze_modules', 'vlm_interface.model.language_model,vlm_interface.model.lm_head')
    if freeze_modules:
        # Names are matched by substring after stripping the interface prefix.
        freeze_list = [m.strip() for m in str(freeze_modules).split(',') if m.strip()]
        if hasattr(self.vlm_interface, "model"):
            for name, param in self.vlm_interface.model.named_parameters():
                if any(fm.replace('vlm_interface.model.', '') in name for fm in freeze_list):
                    param.requires_grad = False
        frozen_p = sum(p.numel() for p in self.parameters() if not p.requires_grad)
        train_p = sum(p.numel() for p in self.parameters() if p.requires_grad)
        logger.info(f"Frozen: {frozen_p/1e6:.0f}M, Trainable: {train_p/1e6:.0f}M (freeze={freeze_modules})")
    else:
        total_p = sum(p.numel() for p in self.parameters())
        logger.info(f"Full finetune: all {total_p/1e6:.0f}M parameters trainable")

    # ── Action/State Normalization (MEAN_STD, matches openpi) ──
    norm_cfg = getattr(config.framework, 'normalization', None)
    self.use_action_norm = norm_cfg is not None and getattr(norm_cfg, 'enabled', False)
    if self.use_action_norm:
        import json as _json  # NOTE(review): imported but unused in this block
        action_mean = torch.tensor(getattr(norm_cfg, 'action_mean', [0.0]*self.action_dim), dtype=torch.float32)
        action_std = torch.tensor(getattr(norm_cfg, 'action_std', [1.0]*self.action_dim), dtype=torch.float32)
        self.register_buffer('action_mean', action_mean)
        self.register_buffer('action_std', action_std)
        if hasattr(norm_cfg, 'state_mean'):
            state_mean = torch.tensor(norm_cfg.state_mean, dtype=torch.float32)
            state_std = torch.tensor(norm_cfg.state_std, dtype=torch.float32)
            self.register_buffer('state_mean', state_mean)
            self.register_buffer('state_std', state_std)
        logger.info(f"[norm] Action MEAN_STD normalization enabled (action_dim={self.action_dim})")
    else:
        logger.info("[norm] No action normalization (raw actions)")
forward
forward(examples: List[dict] = None, **kwargs) -> Tuple

Training forward pass.

Parameters:

Name Type Description Default
examples List[dict]

list of dicts with keys: image, lang, action, (state)

None

Returns:

Type Description
Tuple

(loss, metrics_dict)

Source code in AlphaBrain/model/framework/PaliGemmaPi0.py
def forward(self, examples: List[dict] = None, **kwargs) -> dict:
    """
    Training forward pass.

    Args:
        examples: list of dicts with keys: image, lang, action, (state)

    Returns:
        dict: {"action_loss": mean flow-matching loss tensor (backprop target),
               "flow_matching_loss": the same value as a Python float (for logging)}
    """
    # Stack per-example action chunks into a [B, T, action_dim] tensor on the model device.
    actions = torch.stack([torch.tensor(ex["action"], dtype=torch.float32) for ex in examples])
    actions = actions.to(next(self.parameters()).device)

    # Normalize actions (MEAN_STD) if enabled
    if self.use_action_norm:
        actions = (actions - self.action_mean.to(actions.device)) / (self.action_std.to(actions.device) + 1e-8)

    # Pad action feature dim to model action_dim (e.g., robot 8-dim → base model 32-dim)
    if actions.shape[-1] < self.action_dim:
        actions = torch.nn.functional.pad(actions, (0, self.action_dim - actions.shape[-1]))

    # Truncate/pad actions to action_horizon
    if actions.shape[1] > self.action_horizon:
        actions = actions[:, :self.action_horizon]
    elif actions.shape[1] < self.action_horizon:
        # Zero-pad the time axis when the sampled chunk is shorter than the horizon.
        pad = torch.zeros(
            actions.shape[0], self.action_horizon - actions.shape[1], actions.shape[2],
            device=actions.device, dtype=actions.dtype
        )
        actions = torch.cat([actions, pad], dim=1)

    # Get state if available (None for π₀.₅ with discrete state)
    state = None
    if not self.pi05 and "state" in examples[0]:
        state = torch.stack([torch.tensor(ex["state"], dtype=torch.float32) for ex in examples])
        state = state.to(actions.device)
        # Pad state feature dim to model action_dim
        if state.shape[-1] < self.action_dim:
            state = torch.nn.functional.pad(state, (0, self.action_dim - state.shape[-1]))

    # Encode prefix
    prefix_embs, prefix_pad_masks, prefix_att_masks = self._prepare_prefix(examples)

    # Check VLM type and framework to determine forward path
    vlm_type = self._detect_vlm_type(self.config)
    framework_name = self.config.framework.name

    if framework_name == "LlamaPi0" and vlm_type == "llama":
        # LlamaPi0: Need VLM input embeddings (not full forward output)
        # because _shared_forward_llama handles layer-by-layer processing
        prefix_embs_raw, prefix_pad_masks_raw, prefix_att_masks_raw = self._prepare_prefix_llama_raw(examples)
        vlm_lm = self._get_vlm_language_model()
        loss = self.flow_matching_head.compute_loss_llama(
            prefix_embs=prefix_embs_raw,
            prefix_pad_masks=prefix_pad_masks_raw,
            prefix_att_masks=prefix_att_masks_raw,
            vlm_language_model=vlm_lm,
            state=state,
            actions=actions,
        )
    elif vlm_type == "paligemma":
        # Traditional joint attention path for PaliGemma
        vlm_lm = self._get_vlm_language_model()
        loss = self.flow_matching_head.compute_loss(
            prefix_embs=prefix_embs,
            prefix_pad_masks=prefix_pad_masks,
            prefix_att_masks=prefix_att_masks,
            vlm_language_model=vlm_lm,
            state=state,
            actions=actions,
        )
    else:
        # Prefix cache path for non-PaliGemma VLMs (Llama, etc.)
        loss = self.flow_matching_head.compute_loss_prefix_cache(
            prefix_embs=prefix_embs,
            prefix_pad_masks=prefix_pad_masks,
            prefix_att_masks=prefix_att_masks,
            state=state,
            actions=actions,
        )

    # NOTE(review): this returns a dict, not a (loss, metrics) tuple; the
    # trainer is expected to read "action_loss" for backprop.
    loss_mean = loss.mean()
    return {"action_loss": loss_mean, "flow_matching_loss": loss_mean.item()}
predict_action
predict_action(batch_images: List = None, instructions: List[str] = None, examples: List[dict] = None, unnorm_key=None, **kwargs)

Inference: predict actions via multi-step denoising.

Returns:

Type Description

dict: {"normalized_actions": nested list of shape [B, action_horizon, action_dim]} — the values are de-normalized when MEAN_STD normalization is enabled, raw model output otherwise.

Source code in AlphaBrain/model/framework/PaliGemmaPi0.py
@torch.no_grad()
@torch.amp.autocast('cuda', dtype=torch.bfloat16)
def predict_action(self, batch_images: List = None, instructions: List[str] = None,
                   examples: List[dict] = None, unnorm_key=None, **kwargs):
    """
    Inference: predict actions via multi-step denoising.

    Returns:
        dict: {"normalized_actions": nested list of shape
        [B, action_horizon, action_dim]}.  Despite the key name, the values
        are de-normalized when MEAN_STD normalization is enabled; otherwise
        they are raw model outputs.
    """
    # CRITICAL: disable gradient checkpointing for inference
    # GC is incompatible with KV cache (corrupts cached key/values)
    # NOTE(review): GC is never re-enabled afterwards — confirm training is
    # not resumed on this instance after calling predict_action.
    for m in self.modules():
        if hasattr(m, 'gradient_checkpointing'):
            m.gradient_checkpointing = False

    # Support both flat format (batch_images/instructions) and legacy examples format
    if examples is None and batch_images is not None:
        from PIL import Image  # NOTE(review): Image appears unused in this branch
        states = kwargs.get("states", None)
        examples = []
        for i, imgs in enumerate(batch_images):
            # Pass all views as list for multi-view support
            img = list(imgs) if isinstance(imgs, (list, tuple)) else imgs
            lang = instructions[i] if instructions else ""
            ex = {"image": img, "lang": lang}
            if states is not None and i < len(states):
                ex["state"] = states[i]
            examples.append(ex)

    device = next(self.parameters()).device

    # Continuous state is only used for π₀ (π₀.₅ handles state elsewhere).
    state = None
    if not self.pi05 and "state" in examples[0]:
        state = torch.stack([torch.tensor(ex["state"], dtype=torch.float32) for ex in examples])
        state = state.to(device)

    # Check VLM type and framework to determine inference path
    vlm_type = self._detect_vlm_type(self.config)
    framework_name = self.config.framework.name

    if framework_name == "LlamaPi0" and vlm_type == "llama":
        # LlamaPi0: Llama VLM + Llama Action Expert with joint attention
        prefix_embs, prefix_pad_masks, prefix_att_masks = self._prepare_prefix_generic(examples)
        vlm_lm = self._get_vlm_language_model()
        actions = self.flow_matching_head.sample_actions_llama(
            prefix_embs=prefix_embs,
            prefix_pad_masks=prefix_pad_masks,
            prefix_att_masks=prefix_att_masks,
            vlm_language_model=vlm_lm,
            state=state,
            device=device,
        )

        # Unnormalize actions if MEAN_STD normalization was used
        if self.use_action_norm:
            actions = actions * self.action_std.to(actions.device) + self.action_mean.to(actions.device)

        actions_np = actions.cpu().float().numpy()
        return {"normalized_actions": actions_np.tolist()}
    elif vlm_type != "paligemma":
        # Use prefix cache mode for non-PaliGemma VLMs
        prefix_embs, prefix_pad_masks, prefix_att_masks = self._prepare_prefix_generic(examples)
        actions = self.flow_matching_head.sample_actions_prefix_cache(
            prefix_embs=prefix_embs,
            prefix_pad_masks=prefix_pad_masks,
            prefix_att_masks=prefix_att_masks,
            state=state,
            device=device,
        )

        # Unnormalize actions if MEAN_STD normalization was used
        if self.use_action_norm:
            actions = actions * self.action_std.to(actions.device) + self.action_mean.to(actions.device)

        actions_np = actions.cpu().float().numpy()
        return {"normalized_actions": actions_np.tolist()}

    # Traditional PaliGemma inference path below
    # Bypass _prepare_prefix — use openpi-style embed_prefix directly
    # Inline PaligemmaTokenizer logic (matches openpi's PaligemmaTokenizer)
    # to avoid importing openpi which pulls in jax as a top-level dependency.
    import torchvision.transforms.functional as TF
    import numpy as np_
    import math as _math

    # Ensure tokenizer is initialized (reuse existing _init_tokenizer)
    if self._tokenizer is None and not hasattr(self, '_hf_tokenizer'):
        self._init_tokenizer()

    _PREDICT_MAX_LEN = 200  # match openpi PaligemmaTokenizer default

    def _tokenize_openpi_style(text, max_len=_PREDICT_MAX_LEN):
        """Tokenize text in openpi PaligemmaTokenizer format: BOS + cleaned_text + newline, padded to max_len."""
        cleaned = str(text).strip().replace("_", " ").replace("\n", " ")
        if hasattr(self, '_hf_tokenizer') and self._hf_tokenizer is not None:
            bos_id = self._hf_tokenizer.bos_token_id
            text_ids = self._hf_tokenizer.encode(cleaned, add_special_tokens=False)
            newline_ids = self._hf_tokenizer.encode("\n", add_special_tokens=False)
            ids = ([bos_id] if bos_id is not None else []) + text_ids + newline_ids
        else:
            ids = self._tokenizer.encode(cleaned, add_bos=True) + self._tokenizer.encode("\n")
        tokens_len = len(ids)
        if tokens_len < max_len:
            mask = [True] * tokens_len + [False] * (max_len - tokens_len)
            ids = ids + [0] * (max_len - tokens_len)
        else:
            ids = ids[:max_len]
            mask = [True] * max_len
        return np_.asarray(ids), np_.asarray(mask)

    def _proc_img(im):
        """Convert an HWC numpy image (assumed 0–255 pixel range — TODO confirm) to a [3,224,224] tensor in [-1, 1]."""
        t = torch.from_numpy(im.copy()).float()
        if t.ndim == 3 and t.shape[-1] == 3:
            t = t.permute(2, 0, 1)
        t = t / 255.0
        t = TF.resize(t, [224, 224], antialias=True)
        t = TF.normalize(t, mean=[0.5]*3, std=[0.5]*3)
        return t

    # NOTE(review): only examples[0] is processed here — this PaliGemma path
    # effectively assumes batch size 1; confirm callers never batch.
    ex = examples[0]
    imgs_raw = ex['image'] if isinstance(ex['image'], list) else [ex['image']]
    _dtype = next(self.parameters()).dtype
    img_tensors = [_proc_img(im).unsqueeze(0).to(device).to(_dtype) for im in imgs_raw]
    # Pad to exactly 3 camera views; blank (-1) views are masked out below.
    while len(img_tensors) < 3:
        img_tensors.append(torch.full((1, 3, 224, 224), -1.0, device=device, dtype=_dtype))
    img_masks_list = [torch.tensor([True], device=device)] * len(imgs_raw) + \
                     [torch.tensor([False], device=device)] * (3 - len(imgs_raw))

    tokens, masks = _tokenize_openpi_style(ex['lang'])
    tokens_t = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
    masks_t = torch.tensor(masks, dtype=torch.bool).unsqueeze(0).to(device)

    # Build the prefix: [image embeddings for each view] + [language embeddings].
    embs_list, pad_list, att_list = [], [], []
    for img_t, img_m in zip(img_tensors, img_masks_list):
        img_emb = self.vlm_interface.model.get_image_features(img_t)
        bsize, n_embs = img_emb.shape[:2]
        embs_list.append(img_emb)
        pad_list.append(img_m[:, None].expand(bsize, n_embs))
        # att value 0 for every prefix token; semantics defined by make_att_2d_masks.
        att_list += [0] * n_embs

    lang_emb = self.vlm_interface.model.embed_tokens(tokens_t)
    # Scale token embeddings by sqrt(hidden_dim) — presumably matches openpi's
    # embed_prefix convention; verify against the original implementation.
    lang_emb = lang_emb * _math.sqrt(lang_emb.shape[-1])
    embs_list.append(lang_emb)
    pad_list.append(masks_t)
    att_list += [0] * lang_emb.shape[1]

    prefix_embs = torch.cat(embs_list, dim=1)
    prefix_pad_masks = torch.cat(pad_list, dim=1)
    att_tensor = torch.tensor(att_list, dtype=torch.bool, device=device)
    # bsize carries over from the last loop iteration (all views share batch size).
    prefix_att_masks = att_tensor[None, :].expand(bsize, -1)

    # DEBUG: log prefix stats for comparison with openpi
    if not hasattr(self, '_debug_count'):
        self._debug_count = 0
    if self._debug_count < 3:
        # Log image info from examples
        ex0 = examples[0]
        img0 = ex0['image']
        if isinstance(img0, list):
            for vi, v in enumerate(img0):
                import numpy as np
                if isinstance(v, np.ndarray):
                    print(f"[DEBUG] image[{vi}]: type=ndarray, dtype={v.dtype}, shape={v.shape}, range=[{v.min()}, {v.max()}]")
                elif isinstance(v, torch.Tensor):
                    print(f"[DEBUG] image[{vi}]: type=tensor, dtype={v.dtype}, shape={v.shape}, range=[{v.min()}, {v.max()}]")
                else:
                    print(f"[DEBUG] image[{vi}]: type={type(v)}")
        print(f"[DEBUG] prefix shape={prefix_embs.shape}, "
              f"mean={prefix_embs.float().mean():.6f}, std={prefix_embs.float().std():.6f}, "
              f"img[0:4]={prefix_embs[0,0,:4].float().cpu().tolist()}, "
              f"lang[768:772]={prefix_embs[0,768,:4].float().cpu().tolist()}")
        import sys; sys.stdout.flush(); sys.stderr.flush()
        self._debug_count += 1

    vlm_lm = self._get_vlm_language_model()

    # Use openpi-style KV cache inference directly (proven to work in AB test)
    from AlphaBrain.model.modules.action_model.pi0_flow_matching_head.openpi_inference import make_att_2d_masks
    from transformers.models.gemma import modeling_gemma

    bsize = prefix_pad_masks.shape[0]

    # Step 1: prefix → KV cache through VLM language model
    prefix_att_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
    prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
    prefix_att_4d = prefix_att_2d[:, None, :, :]
    # Large negative fill stands in for -inf at masked attention positions.
    prefix_att_4d = torch.where(prefix_att_4d, 0.0, -2.3819763e38)

    vlm_lm.config._attn_implementation = "eager"
    prefix_output = vlm_lm.forward(
        inputs_embeds=prefix_embs,
        attention_mask=prefix_att_4d,
        position_ids=prefix_position_ids,
        past_key_values=None,
        use_cache=True,
    )
    past_key_values = prefix_output.past_key_values
    # For transformers >= 4.45: DynamicCache accumulates K,V across denoising steps.
    # Convert to legacy tuple format so we can recreate a fresh cache each step.
    if hasattr(past_key_values, 'to_legacy_cache'):
        _vlm_kv_legacy = past_key_values.to_legacy_cache()
        _use_dynamic_cache = True
    else:
        _use_dynamic_cache = False

    # Step 2: iterative denoising through action expert
    # Euler integration of the flow ODE from t=1 down to ~0 in num_steps steps.
    expert_model = self.flow_matching_head.action_expert.model.model
    num_steps = self.flow_matching_head.num_inference_steps
    dt = -1.0 / num_steps
    dt_t = torch.tensor(dt, dtype=torch.float32, device=device)

    noise = torch.randn(bsize, self.action_horizon, self.action_dim, dtype=torch.float32, device=device)
    x_t = noise
    time = torch.tensor(1.0, dtype=torch.float32, device=device)

    while time >= -dt_t / 2:
        expanded_time = time.expand(bsize)
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.flow_matching_head.embed_suffix(
            state, x_t, expanded_time
        )

        suffix_len = suffix_pad_masks.shape[1]
        prefix_len = prefix_pad_masks.shape[1]

        # Suffix tokens attend to all valid prefix tokens plus the suffix mask.
        prefix_pad_2d = prefix_pad_masks[:, None, :].expand(bsize, suffix_len, prefix_len)
        suffix_att_2d = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
        full_att_2d = torch.cat([prefix_pad_2d, suffix_att_2d], dim=2)
        full_att_4d = full_att_2d[:, None, :, :]
        full_att_4d = torch.where(full_att_4d, 0.0, -2.3819763e38)

        # Suffix positions continue after the last valid prefix position.
        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
        suffix_position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1

        expert_model.config._attn_implementation = "eager"
        # Rebuild fresh DynamicCache each step: prevents expert K,V from
        # accumulating across denoising steps (DynamicCache is stateful in 4.45+).
        if _use_dynamic_cache:
            from transformers import DynamicCache
            _kv_for_step = DynamicCache.from_legacy_cache(_vlm_kv_legacy)
        else:
            _kv_for_step = past_key_values
        suffix_output = expert_model.forward(
            inputs_embeds=suffix_embs,
            attention_mask=full_att_4d,
            position_ids=suffix_position_ids,
            past_key_values=_kv_for_step,
            use_cache=False,
            adarms_cond=adarms_cond,
        )

        suffix_out = suffix_output.last_hidden_state
        suffix_out = suffix_out[:, -self.action_horizon:]
        suffix_out = suffix_out.to(dtype=torch.float32)
        v_t = self.flow_matching_head.action_out_proj(suffix_out)

        # Euler step: x_{t+dt} = x_t + dt * v_t (dt is negative).
        x_t = x_t + dt_t * v_t
        time = time + dt_t

    actions = x_t

    # Unnormalize actions if MEAN_STD normalization was used
    if self.use_action_norm:
        actions = actions * self.action_std.to(actions.device) + self.action_mean.to(actions.device)

    actions_np = actions.cpu().float().numpy()

    # Return actions (unnormalized if norm was enabled, raw otherwise)
    return {"normalized_actions": actions_np.tolist()}

    # NOTE(review): everything below is unreachable — it follows the return
    # above.  It looks like setup code (gradient checkpointing + VLM freezing)
    # that belongs in __init__; confirm and remove or relocate.
    # Enable gradient checkpointing to save memory
    if hasattr(self.vlm_interface, 'model'):
        if hasattr(self.vlm_interface.model, 'gradient_checkpointing_enable'):
            self.vlm_interface.model.gradient_checkpointing_enable()
        # For PaliGemmaVLM: enable on sub-models
        if hasattr(self.vlm_interface.model, 'language_model') and hasattr(self.vlm_interface.model.language_model, 'gradient_checkpointing_enable'):
            self.vlm_interface.model.language_model.gradient_checkpointing_enable()
        if hasattr(self.vlm_interface.model, 'vision_tower') and hasattr(self.vlm_interface.model.vision_tower, 'gradient_checkpointing_enable'):
            self.vlm_interface.model.vision_tower.gradient_checkpointing_enable()
    # Freeze VLM language model and lm_head to save GPU memory
    # Only train: vision_tower + projector + action_expert + flow matching head
    if hasattr(self.vlm_interface, 'model'):
        for name, param in self.vlm_interface.model.named_parameters():
            if 'language_model' in name or 'lm_head' in name:
                param.requires_grad = False
        frozen = sum(1 for p in self.vlm_interface.model.parameters() if not p.requires_grad)
        total = sum(1 for p in self.vlm_interface.model.parameters())
        trainable_m = sum(p.numel() for p in self.vlm_interface.model.parameters() if p.requires_grad) / 1e6
        print(f'[PaliGemmaOFT] VLM: frozen {frozen}/{total} params, trainable {trainable_m:.0f}M (vision+projector)')

    if hasattr(self.flow_matching_head, 'action_expert') and hasattr(self.flow_matching_head.action_expert.model, 'gradient_checkpointing_enable'):
        self.flow_matching_head.action_expert.model.gradient_checkpointing_enable()

PaliGemmaPi05

PaliGemmaOFT Framework

Integrates the π₀/π₀.₅ flow matching architecture into VLA-Engine. Key innovation: the VLM backbone is swappable — you can use PaliGemma (original), Qwen2.5-VL, Llama 3.2 Vision, or any future VLM backend.

Architecture

VLM (any) → prefix embedding → [KV cache] → Action Expert (Gemma) + Flow Matching → actions

Components
  • VLM interface: reuses VLAE's existing get_vlm_model() factory
  • Action Expert: independent Gemma transformer (from openpi)
  • Flow Matching Head: multi-step denoising action generation

Training: flow matching loss (MSE between predicted and target velocity fields) Inference: iterative denoising from Gaussian noise (default 10 steps)

PaliGemma_Pi05

PaliGemma_Pi05(config: Optional[dict] = None, **kwargs)

Bases: BaseFramework

Pi0/Pi0.5 framework with swappable VLM backbone.

Config structure

framework:
  name: PaliGemmaOFT
  pi05: true            # true for π₀.₅, false for π₀
  paligemma:            # or qwenvl/llamavl — uses get_vlm_model()
    base_vlm: google/paligemma-3b-pt-224
  action_expert:
    width: 1024
    depth: 18
    ...
  action_model:
    action_dim: 7
    action_horizon: 50
    num_inference_steps: 10

Source code in AlphaBrain/model/framework/PaliGemmaPi05.py
def __init__(self, config: Optional[dict] = None, **kwargs) -> None:
    """
    Build the Pi0/Pi0.5 framework: a swappable VLM backbone plus an action
    expert with a flow-matching head.

    Args:
        config: OmegaConf/namespace config with `framework` (pi05 flag, VLM
            selection, `action_expert`, `action_model`, optional
            `normalization`) and `trainer.freeze_modules`.
    """
    super().__init__()
    self.config = config

    pi05 = getattr(config.framework, 'pi05', True)
    self.pi05 = pi05

    # ── VLM backbone (swappable) ──
    # Determine which VLM to use based on config
    vlm_type = self._detect_vlm_type(config)

    if vlm_type == "paligemma":
        from AlphaBrain.model.modules.vlm.paligemma import _PaliGemma_VL_Interface
        self.vlm_interface = _PaliGemma_VL_Interface(config=config)
    else:
        # Use existing VLM factory for Qwen/Llama/Florence etc.
        self.vlm_interface = get_vlm_model(config=config)

    # ── Action Expert + Flow Matching Head ──
    expert_cfg = config.framework.action_expert
    action_cfg = config.framework.action_model

    # Determine expert type based on framework name and config
    expert_type = "gemma"  # default
    if config.framework.name == "LlamaPi0":
        expert_type = "llama"
    elif hasattr(expert_cfg, 'type'):
        expert_type = expert_cfg.type

    self._tokenizer = None
    # Expert architecture defaults differ by type (Gemma vs Llama head layout).
    self.flow_matching_head = Pi0FlowMatchingHead(
        action_dim=action_cfg.action_dim,
        action_horizon=action_cfg.action_horizon,
        action_expert_width=getattr(expert_cfg, 'width', 1024),
        action_expert_depth=getattr(expert_cfg, 'depth', 18),
        action_expert_mlp_dim=getattr(expert_cfg, 'mlp_dim', 4096),
        action_expert_num_heads=getattr(expert_cfg, 'num_heads', 8 if expert_type == "gemma" else 32),
        action_expert_num_kv_heads=getattr(expert_cfg, 'num_kv_heads', 1 if expert_type == "gemma" else 8),
        action_expert_head_dim=getattr(expert_cfg, 'head_dim', 256 if expert_type == "gemma" else 128),
        pi05=pi05,
        precision=getattr(expert_cfg, 'precision', 'bfloat16'),
        num_inference_steps=getattr(action_cfg, 'num_inference_steps', 10),
        noise_beta_alpha=getattr(action_cfg, 'noise_beta_alpha', 1.5),
        noise_beta_beta=getattr(action_cfg, 'noise_beta_beta', 1.0),
        expert_type=expert_type,
        state_dim=getattr(action_cfg, 'state_dim', None),
    )

    # ── Prefix projection (for VLM hidden_size != action_expert_width) ──
    expert_width = getattr(expert_cfg, 'width', 1024)
    vlm_hidden_size = self._get_vlm_hidden_size()
    if vlm_hidden_size is not None and vlm_hidden_size != expert_width and vlm_type != "paligemma":
        # PaliGemma uses its own encode_prefix which outputs expert_width directly
        # For Qwen/Llama, VLM hidden states need projection to match action expert
        self.prefix_proj = nn.Linear(vlm_hidden_size, expert_width, bias=False)
        logger.info(f"[prefix_proj] Added projection: VLM hidden={vlm_hidden_size} → expert_width={expert_width}")
    else:
        self.prefix_proj = None

    # ── Action dimension settings ──
    self.action_dim = action_cfg.action_dim
    self.action_horizon = action_cfg.action_horizon
    self.future_action_window_size = getattr(action_cfg, 'future_action_window_size', action_cfg.action_horizon)
    self.past_action_window_size = getattr(action_cfg, 'past_action_window_size', 0)
    self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size

    # Store VLM rotary_emb reference for inference consistency
    vlm_lm = self._get_vlm_language_model()
    if vlm_lm is not None and hasattr(vlm_lm, 'rotary_emb'):
        self.flow_matching_head._vlm_rotary_emb = vlm_lm.rotary_emb

    logger.info(
        f"PaliGemmaOFT initialized: pi05={pi05}, vlm={vlm_type}, "
        f"action_dim={self.action_dim}, horizon={self.action_horizon}"
    )

    # Enable gradient checkpointing on action expert only.
    # _shared_forward handles VLM layers via use_gc (compute_layer checkpoint).
    # Enabling HF-level GC on VLM GemmaModel would cause double-checkpoint conflict
    # with DeepSpeed ZeRO: same parameters reduced twice -> AssertionError.
    for m in self.flow_matching_head.modules():
        if hasattr(m, "gradient_checkpointing_enable"):
            m.gradient_checkpointing_enable()
    # Disable HF-level GC on VLM language model to prevent double-backward
    try:
        vlm_lm = self._get_vlm_language_model()
        if hasattr(vlm_lm, "gradient_checkpointing_disable"):
            vlm_lm.gradient_checkpointing_disable()
    except Exception:
        pass

    # Freeze VLM modules based on config (empty string = no freeze = full finetune)
    freeze_modules = getattr(config.trainer, 'freeze_modules', 'vlm_interface.model.language_model,vlm_interface.model.lm_head')
    if freeze_modules:
        freeze_list = [m.strip() for m in str(freeze_modules).split(',') if m.strip()]
        if hasattr(self.vlm_interface, "model"):
            for name, param in self.vlm_interface.model.named_parameters():
                if any(fm.replace('vlm_interface.model.', '') in name for fm in freeze_list):
                    param.requires_grad = False
        frozen_p = sum(p.numel() for p in self.parameters() if not p.requires_grad)
        train_p = sum(p.numel() for p in self.parameters() if p.requires_grad)
        logger.info(f"Frozen: {frozen_p/1e6:.0f}M, Trainable: {train_p/1e6:.0f}M (freeze={freeze_modules})")
    else:
        total_p = sum(p.numel() for p in self.parameters())
        logger.info(f"Full finetune: all {total_p/1e6:.0f}M parameters trainable")

    # ── Action/State Normalization (MEAN_STD, matches openpi) ──
    norm_cfg = getattr(config.framework, 'normalization', None)
    self.use_action_norm = norm_cfg is not None and getattr(norm_cfg, 'enabled', False)
    if self.use_action_norm:
        # Stats are registered as buffers so they move with the module and
        # are saved in checkpoints.
        action_mean = torch.tensor(getattr(norm_cfg, 'action_mean', [0.0]*self.action_dim), dtype=torch.float32)
        action_std = torch.tensor(getattr(norm_cfg, 'action_std', [1.0]*self.action_dim), dtype=torch.float32)
        self.register_buffer('action_mean', action_mean)
        self.register_buffer('action_std', action_std)
        if hasattr(norm_cfg, 'state_mean'):
            state_mean = torch.tensor(norm_cfg.state_mean, dtype=torch.float32)
            state_std = torch.tensor(norm_cfg.state_std, dtype=torch.float32)
            self.register_buffer('state_mean', state_mean)
            self.register_buffer('state_std', state_std)
        logger.info(f"[norm] Action MEAN_STD normalization enabled (action_dim={self.action_dim})")
    else:
        logger.info("[norm] No action normalization (raw actions)")
forward
forward(examples: List[dict] = None, **kwargs) -> Tuple

Training forward pass.

Parameters:

Name Type Description Default
examples List[dict]

list of dicts with keys: image, lang, action, (state)

None

Returns:

Type Description
dict

{"action_loss": mean flow-matching loss tensor, "flow_matching_loss": the same value as a Python float for logging} — note the implementation returns a dict, not a (loss, metrics) tuple.

Source code in AlphaBrain/model/framework/PaliGemmaPi05.py
def forward(self, examples: List[dict] = None, **kwargs) -> dict:
    """
    Training forward pass.

    Args:
        examples: list of dicts with keys: image, lang, action, (state)

    Returns:
        dict: {"action_loss": mean flow-matching loss tensor (backprop target),
               "flow_matching_loss": the same value as a Python float (for logging)}
    """
    # Stack per-example action chunks into a [B, T, action_dim] tensor on the model device.
    actions = torch.stack([torch.tensor(ex["action"], dtype=torch.float32) for ex in examples])
    actions = actions.to(next(self.parameters()).device)

    # Normalize actions (MEAN_STD) if enabled
    if self.use_action_norm:
        actions = (actions - self.action_mean.to(actions.device)) / (self.action_std.to(actions.device) + 1e-8)

    # Pad action feature dim to model action_dim (e.g., robot 8-dim → base model 32-dim)
    if actions.shape[-1] < self.action_dim:
        actions = torch.nn.functional.pad(actions, (0, self.action_dim - actions.shape[-1]))

    # Truncate/pad actions to action_horizon
    if actions.shape[1] > self.action_horizon:
        actions = actions[:, :self.action_horizon]
    elif actions.shape[1] < self.action_horizon:
        # Zero-pad the time axis when the sampled chunk is shorter than the horizon.
        pad = torch.zeros(
            actions.shape[0], self.action_horizon - actions.shape[1], actions.shape[2],
            device=actions.device, dtype=actions.dtype
        )
        actions = torch.cat([actions, pad], dim=1)

    # Get state if available (None for π₀.₅ with discrete state)
    state = None
    if not self.pi05 and "state" in examples[0]:
        state = torch.stack([torch.tensor(ex["state"], dtype=torch.float32) for ex in examples])
        state = state.to(actions.device)
        # Pad state feature dim to model action_dim
        if state.shape[-1] < self.action_dim:
            state = torch.nn.functional.pad(state, (0, self.action_dim - state.shape[-1]))

    # Encode prefix
    prefix_embs, prefix_pad_masks, prefix_att_masks = self._prepare_prefix(examples)

    # Check VLM type and framework to determine forward path
    vlm_type = self._detect_vlm_type(self.config)
    framework_name = self.config.framework.name

    if framework_name == "LlamaPi0" and vlm_type == "llama":
        # LlamaPi0: Need VLM input embeddings (not full forward output)
        # because _shared_forward_llama handles layer-by-layer processing
        prefix_embs_raw, prefix_pad_masks_raw, prefix_att_masks_raw = self._prepare_prefix_llama_raw(examples)
        vlm_lm = self._get_vlm_language_model()
        loss = self.flow_matching_head.compute_loss_llama(
            prefix_embs=prefix_embs_raw,
            prefix_pad_masks=prefix_pad_masks_raw,
            prefix_att_masks=prefix_att_masks_raw,
            vlm_language_model=vlm_lm,
            state=state,
            actions=actions,
        )
    elif vlm_type == "paligemma":
        # Traditional joint attention path for PaliGemma
        vlm_lm = self._get_vlm_language_model()
        loss = self.flow_matching_head.compute_loss(
            prefix_embs=prefix_embs,
            prefix_pad_masks=prefix_pad_masks,
            prefix_att_masks=prefix_att_masks,
            vlm_language_model=vlm_lm,
            state=state,
            actions=actions,
        )
    else:
        # Prefix cache path for non-PaliGemma VLMs (Llama, etc.)
        loss = self.flow_matching_head.compute_loss_prefix_cache(
            prefix_embs=prefix_embs,
            prefix_pad_masks=prefix_pad_masks,
            prefix_att_masks=prefix_att_masks,
            state=state,
            actions=actions,
        )

    # NOTE(review): this returns a dict, not a (loss, metrics) tuple; the
    # trainer is expected to read "action_loss" for backprop.
    loss_mean = loss.mean()
    return {"action_loss": loss_mean, "flow_matching_loss": loss_mean.item()}
predict_action
predict_action(batch_images: List = None, instructions: List[str] = None, examples: List[dict] = None, unnorm_key=None, **kwargs)

Inference: predict actions via multi-step denoising.

Returns:

Type Description

np.ndarray: [B, action_horizon, action_dim] unnormalized actions

Source code in AlphaBrain/model/framework/PaliGemmaPi05.py
@torch.no_grad()
@torch.amp.autocast('cuda', dtype=torch.bfloat16)
def predict_action(self, batch_images: List = None, instructions: List[str] = None,
                   examples: List[dict] = None, unnorm_key=None, **kwargs):
    """
    Inference: predict actions via multi-step flow-matching denoising.

    Accepts either the flat format (batch_images + instructions, with an
    optional kwargs["states"]) or the legacy examples format (list of dicts
    with "image" / "lang" / optional "state" keys). Dispatches on the
    configured framework and VLM type:

      - "LlamaPi0" + llama VLM: joint-attention sampling path
      - any other non-PaliGemma VLM: prefix-KV-cache sampling path
      - PaliGemma: inline openpi-style tokenization, prefix embedding and
        KV-cache denoising implemented below

    Args:
        batch_images: Per-sample image or list of views (flat format).
        instructions: Per-sample language instructions (flat format).
        examples: Legacy list of dicts; used as-is when provided.
        unnorm_key: Accepted for interface compatibility; not used here.
        **kwargs: May carry "states" (per-sample robot states, flat format).

    Returns:
        dict: {"normalized_actions": nested list, [B, action_horizon,
        action_dim]}. Values are unnormalized when self.use_action_norm is
        set, raw otherwise — despite the key name.
    """
    # CRITICAL: disable gradient checkpointing for inference
    # GC is incompatible with KV cache (corrupts cached key/values)
    for m in self.modules():
        if hasattr(m, 'gradient_checkpointing'):
            m.gradient_checkpointing = False

    # Support both flat format (batch_images/instructions) and legacy examples format
    if examples is None and batch_images is not None:
        from PIL import Image
        states = kwargs.get("states", None)
        examples = []
        for i, imgs in enumerate(batch_images):
            # Pass all views as list for multi-view support
            img = list(imgs) if isinstance(imgs, (list, tuple)) else imgs
            lang = instructions[i] if instructions else ""
            ex = {"image": img, "lang": lang}
            if states is not None and i < len(states):
                ex["state"] = states[i]
            examples.append(ex)

    device = next(self.parameters()).device

    # pi05 variants embed state elsewhere; only non-pi05 consumes it here.
    state = None
    if not self.pi05 and "state" in examples[0]:
        state = torch.stack([torch.tensor(ex["state"], dtype=torch.float32) for ex in examples])
        state = state.to(device)

    # Check VLM type and framework to determine inference path
    vlm_type = self._detect_vlm_type(self.config)
    framework_name = self.config.framework.name

    if framework_name == "LlamaPi0" and vlm_type == "llama":
        # LlamaPi0: Llama VLM + Llama Action Expert with joint attention
        prefix_embs, prefix_pad_masks, prefix_att_masks = self._prepare_prefix_generic(examples)
        vlm_lm = self._get_vlm_language_model()
        actions = self.flow_matching_head.sample_actions_llama(
            prefix_embs=prefix_embs,
            prefix_pad_masks=prefix_pad_masks,
            prefix_att_masks=prefix_att_masks,
            vlm_language_model=vlm_lm,
            state=state,
            device=device,
        )

        # Unnormalize actions if MEAN_STD normalization was used
        if self.use_action_norm:
            actions = actions * self.action_std.to(actions.device) + self.action_mean.to(actions.device)
            # Map gripper (dim 6) from [0,1] to [+1,-1] so that client-side
            # -gripper + binarize(>0.5 → open) produces correct LIBERO commands
            actions[:, :, 6] = 1.0 - 2.0 * actions[:, :, 6]

        actions_np = actions.cpu().float().numpy()
        return {"normalized_actions": actions_np.tolist()}
    elif vlm_type != "paligemma":
        # Use prefix cache mode for non-PaliGemma VLMs
        prefix_embs, prefix_pad_masks, prefix_att_masks = self._prepare_prefix_generic(examples)
        actions = self.flow_matching_head.sample_actions_prefix_cache(
            prefix_embs=prefix_embs,
            prefix_pad_masks=prefix_pad_masks,
            prefix_att_masks=prefix_att_masks,
            state=state,
            device=device,
        )

        # Unnormalize actions if MEAN_STD normalization was used
        if self.use_action_norm:
            actions = actions * self.action_std.to(actions.device) + self.action_mean.to(actions.device)
            # Map gripper (dim 6) from [0,1] to [+1,-1] so that client-side
            # -gripper + binarize(>0.5 → open) produces correct LIBERO commands
            actions[:, :, 6] = 1.0 - 2.0 * actions[:, :, 6]

        actions_np = actions.cpu().float().numpy()
        return {"normalized_actions": actions_np.tolist()}

    # Traditional PaliGemma inference path below
    # Bypass _prepare_prefix — use openpi-style embed_prefix directly
    # Inline PaligemmaTokenizer logic (matches openpi's PaligemmaTokenizer)
    # to avoid importing openpi which pulls in jax as a top-level dependency.
    import torchvision.transforms.functional as TF
    import numpy as np_
    import math as _math

    # Ensure tokenizer is initialized (reuse existing _init_tokenizer)
    if self._tokenizer is None and not hasattr(self, '_hf_tokenizer'):
        self._init_tokenizer()

    _PREDICT_MAX_LEN = 200  # match openpi PaligemmaTokenizer default

    def _tokenize_openpi_style(text, max_len=_PREDICT_MAX_LEN):
        """Tokenize text in openpi PaligemmaTokenizer format: BOS + cleaned_text + newline, padded to max_len."""
        cleaned = str(text).strip().replace("_", " ").replace("\n", " ")
        if hasattr(self, '_hf_tokenizer') and self._hf_tokenizer is not None:
            bos_id = self._hf_tokenizer.bos_token_id
            text_ids = self._hf_tokenizer.encode(cleaned, add_special_tokens=False)
            newline_ids = self._hf_tokenizer.encode("\n", add_special_tokens=False)
            ids = ([bos_id] if bos_id is not None else []) + text_ids + newline_ids
        else:
            ids = self._tokenizer.encode(cleaned, add_bos=True) + self._tokenizer.encode("\n")
        tokens_len = len(ids)
        if tokens_len < max_len:
            mask = [True] * tokens_len + [False] * (max_len - tokens_len)
            ids = ids + [0] * (max_len - tokens_len)
        else:
            ids = ids[:max_len]
            mask = [True] * max_len
        return np_.asarray(ids), np_.asarray(mask)

    def _proc_img(im):
        # NOTE(review): assumes HWC uint8 RGB numpy input in [0, 255] —
        # confirm upstream image format.
        t = torch.from_numpy(im.copy()).float()
        if t.ndim == 3 and t.shape[-1] == 3:
            t = t.permute(2, 0, 1)
        t = t / 255.0
        t = TF.resize(t, [224, 224], antialias=True)
        t = TF.normalize(t, mean=[0.5]*3, std=[0.5]*3)
        return t

    # Pad to exactly 3 camera views; missing views are constant -1 images
    # (the post-normalization value of black) with their pad mask False.
    ex = examples[0]
    imgs_raw = ex['image'] if isinstance(ex['image'], list) else [ex['image']]
    _dtype = next(self.parameters()).dtype
    img_tensors = [_proc_img(im).unsqueeze(0).to(device).to(_dtype) for im in imgs_raw]
    while len(img_tensors) < 3:
        img_tensors.append(torch.full((1, 3, 224, 224), -1.0, device=device, dtype=_dtype))
    img_masks_list = [torch.tensor([True], device=device)] * len(imgs_raw) + \
                     [torch.tensor([False], device=device)] * (3 - len(imgs_raw))

    tokens, masks = _tokenize_openpi_style(ex['lang'])
    tokens_t = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
    masks_t = torch.tensor(masks, dtype=torch.bool).unsqueeze(0).to(device)

    # Build the prefix: image tokens for all 3 views, then language tokens.
    # att mask value 0 ⇒ token participates in full (non-causal) prefix attention.
    embs_list, pad_list, att_list = [], [], []
    for img_t, img_m in zip(img_tensors, img_masks_list):
        img_emb = self.vlm_interface.model.get_image_features(img_t)
        bsize, n_embs = img_emb.shape[:2]
        embs_list.append(img_emb)
        pad_list.append(img_m[:, None].expand(bsize, n_embs))
        att_list += [0] * n_embs

    # Scale language embeddings by sqrt(hidden_dim), as in the openpi prefix embedding.
    lang_emb = self.vlm_interface.model.embed_tokens(tokens_t)
    lang_emb = lang_emb * _math.sqrt(lang_emb.shape[-1])
    embs_list.append(lang_emb)
    pad_list.append(masks_t)
    att_list += [0] * lang_emb.shape[1]

    prefix_embs = torch.cat(embs_list, dim=1)
    prefix_pad_masks = torch.cat(pad_list, dim=1)
    att_tensor = torch.tensor(att_list, dtype=torch.bool, device=device)
    prefix_att_masks = att_tensor[None, :].expand(bsize, -1)

    # DEBUG: log prefix stats for comparison with openpi
    if not hasattr(self, '_debug_count'):
        self._debug_count = 0
    if self._debug_count < 3:
        # Log image info from examples
        ex0 = examples[0]
        img0 = ex0['image']
        if isinstance(img0, list):
            for vi, v in enumerate(img0):
                import numpy as np
                if isinstance(v, np.ndarray):
                    print(f"[DEBUG] image[{vi}]: type=ndarray, dtype={v.dtype}, shape={v.shape}, range=[{v.min()}, {v.max()}]")
                elif isinstance(v, torch.Tensor):
                    print(f"[DEBUG] image[{vi}]: type=tensor, dtype={v.dtype}, shape={v.shape}, range=[{v.min()}, {v.max()}]")
                else:
                    print(f"[DEBUG] image[{vi}]: type={type(v)}")
        # NOTE(review): index 768 presumes 256 image tokens per view × 3
        # views before the language tokens — verify if resolution changes.
        print(f"[DEBUG] prefix shape={prefix_embs.shape}, "
              f"mean={prefix_embs.float().mean():.6f}, std={prefix_embs.float().std():.6f}, "
              f"img[0:4]={prefix_embs[0,0,:4].float().cpu().tolist()}, "
              f"lang[768:772]={prefix_embs[0,768,:4].float().cpu().tolist()}")
        import sys; sys.stdout.flush(); sys.stderr.flush()
        self._debug_count += 1

    vlm_lm = self._get_vlm_language_model()

    # Use openpi-style KV cache inference directly (proven to work in AB test)
    from AlphaBrain.model.modules.action_model.pi0_flow_matching_head.openpi_inference import make_att_2d_masks
    from transformers.models.gemma import modeling_gemma

    bsize = prefix_pad_masks.shape[0]

    # Step 1: prefix → KV cache through VLM language model
    prefix_att_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
    prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
    prefix_att_4d = prefix_att_2d[:, None, :, :]
    # Large negative ≈ -inf that stays finite in bfloat16 (openpi's constant).
    prefix_att_4d = torch.where(prefix_att_4d, 0.0, -2.3819763e38)

    # Force eager attention so the hand-built 4D additive mask is honored.
    vlm_lm.config._attn_implementation = "eager"
    prefix_output = vlm_lm.forward(
        inputs_embeds=prefix_embs,
        attention_mask=prefix_att_4d,
        position_ids=prefix_position_ids,
        past_key_values=None,
        use_cache=True,
    )
    past_key_values = prefix_output.past_key_values
    # For transformers >= 4.45: DynamicCache accumulates K,V across denoising steps.
    # Convert to legacy tuple format so we can recreate a fresh cache each step.
    if hasattr(past_key_values, 'to_legacy_cache'):
        _vlm_kv_legacy = past_key_values.to_legacy_cache()
        _use_dynamic_cache = True
    else:
        _use_dynamic_cache = False

    # Step 2: iterative denoising through action expert
    expert_model = self.flow_matching_head.action_expert.model.model
    num_steps = self.flow_matching_head.num_inference_steps
    # Flow time runs from 1.0 down to 0.0 in num_steps equal negative steps.
    dt = -1.0 / num_steps
    dt_t = torch.tensor(dt, dtype=torch.float32, device=device)

    noise = torch.randn(bsize, self.action_horizon, self.action_dim, dtype=torch.float32, device=device)
    x_t = noise
    time = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Loop condition tolerates float accumulation error (half-step slack).
    while time >= -dt_t / 2:
        expanded_time = time.expand(bsize)
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.flow_matching_head.embed_suffix(
            state, x_t, expanded_time
        )

        suffix_len = suffix_pad_masks.shape[1]
        prefix_len = prefix_pad_masks.shape[1]

        # Suffix tokens attend to every non-padded prefix token plus the
        # suffix's own (mask-derived) structure.
        prefix_pad_2d = prefix_pad_masks[:, None, :].expand(bsize, suffix_len, prefix_len)
        suffix_att_2d = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
        full_att_2d = torch.cat([prefix_pad_2d, suffix_att_2d], dim=2)
        full_att_4d = full_att_2d[:, None, :, :]
        full_att_4d = torch.where(full_att_4d, 0.0, -2.3819763e38)

        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
        suffix_position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1

        expert_model.config._attn_implementation = "eager"
        # Rebuild fresh DynamicCache each step: prevents expert K,V from
        # accumulating across denoising steps (DynamicCache is stateful in 4.45+).
        if _use_dynamic_cache:
            from transformers import DynamicCache
            _kv_for_step = DynamicCache.from_legacy_cache(_vlm_kv_legacy)
        else:
            _kv_for_step = past_key_values
        suffix_output = expert_model.forward(
            inputs_embeds=suffix_embs,
            attention_mask=full_att_4d,
            position_ids=suffix_position_ids,
            past_key_values=_kv_for_step,
            use_cache=False,
            adarms_cond=adarms_cond,
        )

        # Last action_horizon positions of the suffix carry the action tokens.
        suffix_out = suffix_output.last_hidden_state
        suffix_out = suffix_out[:, -self.action_horizon:]
        suffix_out = suffix_out.to(dtype=torch.float32)
        v_t = self.flow_matching_head.action_out_proj(suffix_out)

        # Euler integration step of the learned velocity field.
        x_t = x_t + dt_t * v_t
        time = time + dt_t

    actions = x_t

    # Unnormalize actions if MEAN_STD normalization was used
    if self.use_action_norm:
        actions = actions * self.action_std.to(actions.device) + self.action_mean.to(actions.device)
        # Map gripper (dim 6) from [0,1] to [+1,-1] so that client-side
        # -gripper + binarize(>0.5 → open) produces correct LIBERO commands
        actions[:, :, 6] = 1.0 - 2.0 * actions[:, :, 6]

    actions_np = actions.cpu().float().numpy()

    # Return actions (unnormalized if norm was enabled, raw otherwise)
    return {"normalized_actions": actions_np.tolist()}

    # Enable gradient checkpointing to save memory
    if hasattr(self.vlm_interface, 'model'):
        if hasattr(self.vlm_interface.model, 'gradient_checkpointing_enable'):
            self.vlm_interface.model.gradient_checkpointing_enable()
        # For PaliGemmaVLM: enable on sub-models
        if hasattr(self.vlm_interface.model, 'language_model') and hasattr(self.vlm_interface.model.language_model, 'gradient_checkpointing_enable'):
            self.vlm_interface.model.language_model.gradient_checkpointing_enable()
        if hasattr(self.vlm_interface.model, 'vision_tower') and hasattr(self.vlm_interface.model.vision_tower, 'gradient_checkpointing_enable'):
            self.vlm_interface.model.vision_tower.gradient_checkpointing_enable()
    # Freeze VLM language model and lm_head to save GPU memory
    # Only train: vision_tower + projector + action_expert + flow matching head
    if hasattr(self.vlm_interface, 'model'):
        for name, param in self.vlm_interface.model.named_parameters():
            if 'language_model' in name or 'lm_head' in name:
                param.requires_grad = False
        frozen = sum(1 for p in self.vlm_interface.model.parameters() if not p.requires_grad)
        total = sum(1 for p in self.vlm_interface.model.parameters())
        trainable_m = sum(p.numel() for p in self.vlm_interface.model.parameters() if p.requires_grad) / 1e6
        print(f'[PaliGemmaOFT] VLM: frozen {frozen}/{total} params, trainable {trainable_m:.0f}M (vision+projector)')

    if hasattr(self.flow_matching_head, 'action_expert') and hasattr(self.flow_matching_head.action_expert.model, 'gradient_checkpointing_enable'):
        self.flow_matching_head.action_expert.model.gradient_checkpointing_enable()

Llama OFT

LlamaOFT

Llama-OFT Framework

Uses Llama 3.2 Vision as backbone with action special token for continuous action prediction. Mirrors QwenOFT but swaps Qwen for Llama 3.2 Vision.

Llama_OFT

Llama_OFT(config: Optional[dict] = None, **kwargs)

Bases: BaseFramework

Llama 3.2 Vision + action token OFT framework. Predicts continuous actions via L1 regression on action token hidden states.

Source code in AlphaBrain/model/framework/LlamaOFT.py
def __init__(self, config: Optional[dict] = None, **kwargs) -> None:
    """Build the Llama 3.2 Vision backbone plus the L1-regression action head."""
    super().__init__()
    self.config = config
    self.llama_vl_interface = get_vlm_model(config=self.config)

    # The action head must emit vectors of the LLM's hidden width, so the
    # config is patched before the head is constructed.
    text_hidden = self.llama_vl_interface.model.config.text_config.hidden_size
    config.framework.action_model.action_hidden_dim = text_hidden
    self.action_model = get_action_model(config=self.config)

    # Gradient checkpointing keeps the 11B backbone trainable in memory.
    backbone = self.llama_vl_interface.model
    if hasattr(backbone, 'gradient_checkpointing_enable'):
        backbone.gradient_checkpointing_enable()
        logger.info("Enabled gradient checkpointing for Llama model")

    action_cfg = config.framework.action_model
    self.future_action_window_size = action_cfg.future_action_window_size
    self.past_action_window_size = action_cfg.past_action_window_size
    self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size

    # Placeholder token whose hidden states the action head reads.
    self.action_token = "🔍"
    encoded = self.llama_vl_interface.processor.tokenizer(
        self.action_token, add_special_tokens=False
    )
    self.action_token_id = encoded["input_ids"][0]

    self.l1_loss = nn.L1Loss()

Qwen family

QwenOFT

Qwen-OFT Framework

A lightweight implementation that uses an action special token to predict continuous actions in parallel, conditioned on multi-view images plus a language instruction (shares parameters with the VLM). Inspired by OpenVLA-OFT Key Points: - Qwen2.5 vision-language backbone - Injects an action special token into the VLM - Continuous action prediction via L1 regression over the action special token hidden states

How to add special tokens to Qwen2.5:

download our model checkpoint with special tokens added: https://huggingface.co/AlphaBrain/Qwen2.5-VL-3B-Instruct-Action or /AlphaBrain/model/modules/vlm/tools/add_qwen_special_tokens/README.md (adapt a little code)

Qwenvl_OFT

Qwenvl_OFT(config: Optional[dict] = None, **kwargs)

Bases: BaseFramework

Multimodal vision-language-action model.

Components
  • Qwen2.5 VL interface for fused language/vision token embeddings
  • Layer-wise QFormer for multi-layer feature aggregation
  • DINO encoder for dense multi-view spatial tokens
  • DiT diffusion head for future action sequence modeling

Focus: Predict future continuous actions conditioned on images + instruction.

Construct all submodules and cache key configuration values.

Parameters:

Name Type Description Default
config Optional[dict]

Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections.

None
**kwargs

Reserved for future overrides (unused).

{}
Source code in AlphaBrain/model/framework/QwenOFT.py
def __init__(
    self,
    config: Optional[dict] = None,
    **kwargs,
) -> None:
    """
    Wire up the QwenVL backbone and the action head, caching the
    configuration values the forward/predict paths rely on.

    Args:
        config: Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections.
        **kwargs: Reserved for future overrides (unused).
    """
    super().__init__()
    self.config = config
    self.qwen_vl_interface = get_vlm_model(config=self.config)

    # The action head must project to the LLM hidden width; write it back
    # into the config before the head is built. (Open question: should this
    # live in the config file instead?)
    config.framework.action_model.action_hidden_dim = self.qwen_vl_interface.model.config.hidden_size
    self.action_model = get_action_model(config=self.config)

    action_cfg = config.framework.action_model
    self.future_action_window_size = action_cfg.future_action_window_size
    self.past_action_window_size = action_cfg.past_action_window_size
    self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size

    # Placeholder token injected into the prompt, one copy per action step.
    # TODO: a dedicated Qwen special token would be cleaner but is more involved.
    self.action_token = "🔍"
    self.action_token_id = self.qwen_vl_interface.processor.tokenizer(
        self.action_token, add_special_tokens=False
    )["input_ids"][0]

    # L1 regression loss over predicted action chunks.
    self.l1_loss = nn.L1Loss()
forward
forward(examples: List[dict] = None, **kwargs) -> Tuple

Training forward pass: directly regress future actions (no diffusion).

Flow
  1. Build QwenVL inputs (images + instruction tokens)
  2. Extract hidden states from configured layer range
  3. Predict action and compute L1 loss

Parameters:

Name Type Description Default
examples List[dict]

List[dict], each dict requires: - image: List[PIL.Image] (multi-view) - lang: str instruction - action: np.ndarray or list shaped [T, action_dim]

None
**kwargs

Reserved.

{}

Returns:

Name Type Description
dict Tuple

action_loss (torch.Tensor): Scalar L1 action regression loss.

Source code in AlphaBrain/model/framework/QwenOFT.py
def forward(
    self,
    examples: List[dict] = None,
    **kwargs,
) -> dict:
    """
    Training forward pass: directly regress future actions (no diffusion).

    Flow:
      1. Build QwenVL inputs (images + instruction tokens, with one
         placeholder action token appended per chunk step)
      2. Extract the last-layer hidden states of the action tokens
      3. Predict actions and compute the L1 regression loss

    Args:
        examples: List[dict], each dict requires:
            - image: List[PIL.Image] (multi-view)
            - lang: str instruction
            - action: np.ndarray or list shaped [T, action_dim]
        **kwargs: Reserved.

    Returns:
        dict:
            action_loss (torch.Tensor): Scalar L1 action regression loss.
    """
    batch_images = [example["image"] for example in examples]  #  [B,[PLT]]
    instructions = [example["lang"] for example in examples]  # [B, str]
    actions = [example["action"] for example in examples]  # label [B, len, 7]

    # step 0: add special action token to instruction
    # Placeholder tokens must be concatenated without spaces, otherwise the
    # tokenizer splits each into multiple ids.
    action_tokens = self.action_token * self.chunk_len
    prompt_suffix = f" Please predict the next {self.chunk_len} robot actions: <action>{action_tokens}<action>."
    instructions = [instruction + prompt_suffix for instruction in instructions]

    # Step 1: QWenVL input format
    qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions)
    with torch.autocast("cuda", dtype=torch.bfloat16):
        qwenvl_outputs = self.qwen_vl_interface(
            **qwen_inputs,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )
        # last_hidden_state: [B, seq_len, H]
        last_hidden = qwenvl_outputs.hidden_states[-1]   # [B, L, H]

    # Step 2: Action Expert Forward and Loss
    with torch.autocast("cuda", dtype=torch.float32):
        # Extract action-token embeddings as action prediction queries
        input_ids = qwen_inputs.get("input_ids", None)
        action_queries = self._gather_action_token_embeddings(last_hidden, input_ids, action_token_id=self.action_token_id)  # [B, chunk_len, H]
        pred_actions = self.action_model.predict_action(action_queries)  # (B, chunk_len, action_dim)

        # Target alignment: take the trailing chunk_len steps of the label.
        actions = torch.tensor(
            np.array(actions), device=pred_actions.device, dtype=pred_actions.dtype
        )  # [B, T_full, action_dim]
        # BUGFIX: slice the last chunk_len steps so the target matches
        # pred_actions. The previous slice used future_action_window_size + 1,
        # which mismatches pred_actions whenever past_action_window_size > 0.
        # (Identical behavior when past_action_window_size == 0.)
        actions_target = actions[:, -self.chunk_len:, :]  # (B, chunk_len, action_dim)

        # L1 regression loss
        action_loss = self.l1_loss(pred_actions, actions_target)

    return {"action_loss": action_loss}
predict_action
predict_action(batch_images: List = None, instructions: List[str] = None, examples: List[dict] = None, **kwargs) -> np.ndarray

Inference: a single forward pass that directly regresses future actions (no diffusion sampling).

Accepts two input formats
  • Flat format (from M1Inference websocket client): batch_images + instructions
  • Legacy format: examples (list of dicts with "image" and "lang" keys)
Steps
  1. Resize images to training resolution (if specified)
  2. Encode with QwenVL (hidden states retained)
  3. Return normalized action trajectory

Returns:

Name Type Description
dict ndarray

normalized_actions (np.ndarray): Shape [B, T, action_dim], normalized actions regressed from the action-token hidden states.

Source code in AlphaBrain/model/framework/QwenOFT.py
@torch.inference_mode()
def predict_action(
    self,
    batch_images: List = None,
    instructions: List[str] = None,
    examples: List[dict] = None,        # TODO: clean up the examples interface
    **kwargs,
) -> dict:
    """
    Inference: a single forward pass that directly regresses future actions
    (no diffusion sampling).

    Accepts two input formats:
      - Flat format (from M1Inference websocket client): batch_images + instructions
      - Legacy format: examples (list of dicts with "image" and "lang" keys)

    Steps:
      1. Resize images to training resolution (if specified)
      2. Encode with QwenVL (hidden states retained)
      3. Return normalized action trajectory

    Returns:
        dict:
            normalized_actions (np.ndarray): Shape [B, chunk_len, action_dim],
            normalized actions regressed from the action-token hidden states.
    """
    # Support both flat format (batch_images/instructions) and legacy examples format
    if examples is not None:
        batch_images = [to_pil_preserve(example["image"]) for example in examples]  #  [B,[PLT]]
        instructions = [example["lang"] for example in examples]  # [B, str]
    else:
        batch_images = [to_pil_preserve(imgs) for imgs in batch_images]  # [B, [PLT]]

    # Match the image resolution used at training time, when configured.
    train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
    if train_obs_image_size:
        batch_images = resize_images(batch_images, target_size=train_obs_image_size)

    # step 0: add special action token to instruction
    # Placeholder tokens are concatenated without spaces; a space would make
    # the tokenizer split each into multiple ids.
    action_tokens = self.action_token* self.chunk_len #can't add " " between two tokens, otherwise will be tokenized to multiple tokens
    prompt_suffix = f" Please predict the next {self.chunk_len} robot actions: <action>{action_tokens}<action>."
    instructions = [instruction + prompt_suffix for instruction in instructions]

    # Step 1: QWenVL input format
    qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions)
    with torch.autocast("cuda", dtype=torch.bfloat16):
        qwenvl_outputs = self.qwen_vl_interface(
            **qwen_inputs,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )
        # last_hidden_state: [B, seq_len, H]
        last_hidden = qwenvl_outputs.hidden_states[-1]   # [B, L, H]

    # Step 2: Action Expert Forward
    with torch.autocast("cuda", dtype=torch.float32):
        # Extract action-token embeddings as action prediction queries
        input_ids = qwen_inputs.get("input_ids", None)
        action_queries = self._gather_action_token_embeddings(last_hidden, input_ids, action_token_id=self.action_token_id)  # [B, chunk_len, H]
        pred_actions = self.action_model.predict_action(action_queries)  # (B, chunk_len, action_dim)

    normalized_actions = pred_actions.detach().cpu().numpy()
    return {"normalized_actions": normalized_actions}
get_action_queries
get_action_queries(batch_images: List = None, instructions: List[str] = None) -> torch.Tensor

Extract action_queries from frozen VLM without going through the action head.

Returns:

Name Type Description
action_queries Tensor

(B, chunk_len, H) tensor on model device

Source code in AlphaBrain/model/framework/QwenOFT.py
@torch.no_grad()
def get_action_queries(
    self,
    batch_images: List = None,
    instructions: List[str] = None,
) -> torch.Tensor:
    """
    Run the frozen VLM and pull out the hidden states of the action
    placeholder tokens, without invoking the action head.

    Returns:
        action_queries: (B, chunk_len, H) tensor on model device
    """
    from deployment.model_server.tools.image_tools import to_pil_preserve

    # Normalize every sample's views to PIL images.
    pil_batches = [to_pil_preserve(views) for views in batch_images]

    # Match the resolution used at training time, when configured.
    target_size = getattr(self.config.datasets.vla_data, "image_size", None)
    if target_size:
        pil_batches = resize_images(pil_batches, target_size=target_size)

    # Append one placeholder token per action step (no separators, or the
    # tokenizer would split each into multiple ids).
    action_tokens = self.action_token * self.chunk_len
    prompt_suffix = f" Please predict the next {self.chunk_len} robot actions: <action>{action_tokens}<action>."
    prompts = [inst + prompt_suffix for inst in instructions]

    vl_inputs = self.qwen_vl_interface.build_qwenvl_inputs(
        images=pil_batches, instructions=prompts
    )
    with torch.autocast("cuda", dtype=torch.bfloat16):
        vl_outputs = self.qwen_vl_interface(
            **vl_inputs,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )
        final_hidden = vl_outputs.hidden_states[-1]  # [B, L, H]

    with torch.autocast("cuda", dtype=torch.float32):
        token_ids = vl_inputs.get("input_ids", None)
        action_queries = self._gather_action_token_embeddings(
            final_hidden, token_ids, action_token_id=self.action_token_id
        )
    return action_queries  # (B, chunk_len, H)
get_vla_action
get_vla_action(batch_images: List = None, instructions: List[str] = None)

Get both action_queries and VLA base action predictions (frozen).

Returns:

Name Type Description
action_queries

(B, chunk_len, H) tensor

vla_actions

(B, chunk_len, action_dim) tensor (normalized)

Source code in AlphaBrain/model/framework/QwenOFT.py
@torch.no_grad()
def get_vla_action(
    self,
    batch_images: List = None,
    instructions: List[str] = None,
):
    """
    Run the frozen VLM and action head once, returning both intermediates.

    Returns:
        action_queries: (B, chunk_len, H) tensor
        vla_actions: (B, chunk_len, action_dim) tensor (normalized)
    """
    queries = self.get_action_queries(batch_images, instructions)
    with torch.autocast("cuda", dtype=torch.float32):
        base_actions = self.action_model.predict_action(queries)  # (B, chunk_len, action_dim)
    return queries, base_actions

QwenPI

Qwen-GROOT Framework A lightweight implementation that combines Qwen2.5-vl with a flow-matching head to directly predict continuous actions. The flow-matching head is adapted from GR00T N1.5, with a simple MoE inspired by PI_0

Extended (2026-04): World model backbone support (V-JEPA, Cosmos, Wan). When a world model VLM is used, forward_all_layers() extracts per-backbone- block features so each PI action head layer cross-attends to a DIFFERENT backbone layer (true layerwise cross-attention, no replication).

Qwen_PI

Qwen_PI(config: Optional[dict] = None, **kwargs)

Bases: BaseFramework

Multimodal vision-language-action model.

Components
  • Qwen2.5 VL interface for fused language/vision token embeddings
  • Layer-wise cross DiT diffusion head
World model mode

When the VLM is a world model (Cosmos, Wan), per-DiT-block features are extracted via forward_all_layers() so each PI action head layer cross-attends to a DIFFERENT backbone layer. framework.qwenvl.num_vl_layers must match the backbone block count (28 for Cosmos, 30 for Wan).

Focus: Predict future continuous actions conditioned on images + instruction.

Source code in AlphaBrain/model/framework/QwenPI.py
def __init__(
    self,
    config: Optional[dict] = None,
    **kwargs,
) -> None:
    """
    Construct the VLM backbone and the layerwise flow-matching action head.

    Args:
        config: Hierarchical configuration (OmegaConf/dict) containing
            framework.qwenvl.*, framework.action_model.* and, optionally,
            framework.world_model.* sections.
        **kwargs: Reserved for future overrides (unused).
    """
    super().__init__()
    self.config = config
    self.qwen_vl_interface = get_vlm_model(config=self.config)

    # ----- Detect VLM backend type -----
    self._world_model_mode = _is_world_model_vlm(self.qwen_vl_interface)
    self._vlm_backend = self._detect_vlm_backend(config)
    logger.info("[QwenPI] VLM backend: %s", self._vlm_backend)

    if self._world_model_mode:
        # World model backbone: each PI DiT layer cross-attends to a
        # distinct backbone block, so num_vl_layers must match block count.
        llm_hidden_size = self.qwen_vl_interface.model.config.hidden_size
        num_vl_layers = getattr(config.framework.qwenvl, "num_vl_layers", 28)
        _backend = "unknown"
        if hasattr(self.qwen_vl_interface, "wm_config"):
            _backend = getattr(self.qwen_vl_interface.wm_config, "backend", "unknown")
        logger.info(
            "[QwenPI] World model mode: backend=%s, hidden_size=%d, "
            "num_vl_layers(backbone blocks for layerwise XAttn)=%d",
            _backend, llm_hidden_size, num_vl_layers,
        )
    elif self._vlm_backend == "llama":
        # Llama 3.2 Vision: VLM hidden_size may differ from DiT hidden_size
        model_cfg = self.qwen_vl_interface.model.config
        self._llama_vlm_hidden_size = model_cfg.text_config.hidden_size  # 4096
        llama_num_layers = model_cfg.text_config.num_hidden_layers  # 40

        # DiT dimensions: use action_model config if available, else match VLM
        dit_hidden = getattr(config.framework.action_model, 'action_hidden_dim', self._llama_vlm_hidden_size)
        dit_layers = getattr(config.framework.action_model, 'num_dit_layers', llama_num_layers)

        # If DiT hidden != VLM hidden, we need a projection layer
        if dit_hidden != self._llama_vlm_hidden_size:
            self._vlm_to_dit_proj = nn.Linear(self._llama_vlm_hidden_size, dit_hidden, bias=False)
            # BUGFIX: original format string was "%d%d" (no separator), which
            # fused the two sizes into one number in the log output.
            logger.info("[LlamaPi0FM] Added VLM→DiT projection: %d → %d",
                        self._llama_vlm_hidden_size, dit_hidden)
        else:
            self._vlm_to_dit_proj = None

        llm_hidden_size = dit_hidden  # DiT sees this dimension
        num_vl_layers = dit_layers
        logger.info("[LlamaPi0FM] Llama VLM: hidden=%d, layers=%d → DiT: hidden=%d, layers=%d", 
                   self._llama_vlm_hidden_size, llama_num_layers, dit_hidden, dit_layers)

        # Enable gradient checkpointing for 11B model
        if hasattr(self.qwen_vl_interface.model, 'gradient_checkpointing_enable'):
            self.qwen_vl_interface.model.gradient_checkpointing_enable()
            logger.info("Enabled gradient checkpointing for Llama model")
    else:
        # Standard Qwen2.5-VL path (original behaviour)
        num_vl_layers, llm_hidden_size = 36, self.qwen_vl_interface.model.config.hidden_size

    # Ensure qwenvl namespace exists for action model config (LayerwiseFM reads from it)
    if not hasattr(self.config.framework, 'qwenvl'):
        from omegaconf import OmegaConf, DictConfig
        # Handle AccessTrackedConfig wrapper
        fw_cfg = self.config.framework
        if hasattr(fw_cfg, '_cfg') and isinstance(fw_cfg._cfg, DictConfig):
            fw_cfg._cfg.qwenvl = OmegaConf.create({"vl_hidden_dim": llm_hidden_size, "num_vl_layers": num_vl_layers})
            fw_cfg._children.pop('qwenvl', None)  # clear cached child
        else:
            fw_cfg.qwenvl = OmegaConf.create({"vl_hidden_dim": llm_hidden_size, "num_vl_layers": num_vl_layers})
    else:
        self.config.framework.qwenvl.vl_hidden_dim = llm_hidden_size
        self.config.framework.qwenvl.num_vl_layers = num_vl_layers

    self.action_model: LayerwiseFlowmatchingActionHead = get_action_model(config=self.config)

    self.future_action_window_size = config.framework.action_model.future_action_window_size
    self.past_action_window_size = config.framework.action_model.past_action_window_size
    self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size

    # Determine whether state was used during training
    self.use_state = getattr(
        getattr(getattr(config, "datasets", None), "vla_data", None),
        "include_state", False,
    )
    # Normalize stringly-typed config values ("False", "false", "", None) to bool.
    if self.use_state in ["False", False, None, "false", ""]:
        self.use_state = False
    else:
        self.use_state = True
    logger.info("[QwenPI] use_state=%s", self.use_state)

    # Video loss is always enabled (unified mode)
    self._video_loss_weight = float(
        getattr(
            getattr(getattr(config, 'framework', None), 'world_model', None),
            'video_loss_weight', 1.0,
        )
    ) if (hasattr(config, 'framework') and hasattr(config.framework, 'world_model')) else 1.0
    logger.info("[QwenPI] video_loss_weight=%.3f", self._video_loss_weight)
forward
forward(examples: List[dict] = None, **kwargs) -> Tuple

Parameters:

Name Type Description Default
examples List[dict]

List[dict], each dict requires: - image: List[PIL.Image] (multi-view) - lang: str instruction - action: np.ndarray or list shaped [T, action_dim]

None

Returns: dict: action_loss (torch.Tensor): Scalar diffusion noise prediction loss.

Source code in AlphaBrain/model/framework/QwenPI.py
def forward(
    self,
    examples: List[dict] = None,
    **kwargs,
) -> dict:
    """
    Training forward: encode vision+language, then compute the flow-matching
    action loss (and a next-frame video loss when next_image is provided).

    Args:
        examples: List[dict], each dict requires:
            - image: List[PIL.Image] (multi-view)
            - lang: str instruction
            - action: np.ndarray or list shaped [T, action_dim]
            Optional keys: state (when use_state), next_image (video loss).
    Returns:
        dict:
            action_loss (torch.Tensor): Scalar flow-matching loss.
            video_loss / total_loss: included only when the video path ran.
    """
    batch_images = [example["image"] for example in examples]  #  [B,[PLT]]
    instructions = [example["lang"] for example in examples]  # [B, str]
    actions = [example["action"] for example in examples]  # label [B, len, 7]

    state = [example["state"] for example in examples] if (self.use_state and "state" in examples[0]) else None

    video_loss = None

    # ===================================================================
    # Video loss path (training)
    # ===================================================================
    has_next_images = (
        len(examples) > 0
        and "next_image" in examples[0]
        and examples[0]["next_image"] is not None
    )
    has_visual_encoder = hasattr(self.qwen_vl_interface, "visual_encoder")

    if has_next_images and has_visual_encoder:
        wm_encoder = self.qwen_vl_interface.visual_encoder
        next_images_raw = [example.get("next_image") for example in examples]
        valid_mask = [img is not None for img in next_images_raw]

        if any(valid_mask):
            # World model consumes one frame per sample: use the first view.
            curr_images_flat = [
                imgs[0] if isinstance(imgs, (list, tuple)) else imgs
                for imgs in batch_images
            ]
            # Pad missing next_images with the current frame so the batch
            # stays uniform (loss is rescaled below).
            dummy_next = [
                next_images_raw[i] if valid_mask[i] else curr_images_flat[i]
                for i in range(len(examples))
            ]

            with torch.autocast("cuda", dtype=torch.bfloat16):
                curr_pv = wm_encoder.preprocess(curr_images_flat)
                next_pv = wm_encoder.preprocess(dummy_next)
                qwenvl_outputs, video_loss_raw = self.qwen_vl_interface.forward_with_video_loss(
                    curr_pv, instructions, next_pv
                )
                last_hidden = qwenvl_outputs.hidden_states[-1]  # [B, L, H]

                # Single-layer features from video loss path; replicate for all PI DiT layers
                expected_layers = len(self.action_model.model.transformer_blocks)
                vl_embs_list = [last_hidden] * expected_layers
                base_hidden = last_hidden

            if not all(valid_mask):
                # Rescale so padded (invalid) samples do not dilute the loss.
                valid_count = sum(valid_mask)
                scale = len(valid_mask) / max(valid_count, 1)
                video_loss = video_loss_raw * scale
            else:
                video_loss = video_loss_raw
        else:
            pass  # fall through to standard encode path below

    # BUGFIX: also take the standard path when the video branch was entered
    # but produced nothing (no valid next_image) — otherwise vl_embs_list /
    # base_hidden would be undefined below. Mirrors Qwen_GR00T.forward.
    if not (has_next_images and has_visual_encoder) or video_loss is None:
        # ===============================================================
        # Standard encode (inference / no next_image)
        # ===============================================================
        # Step 1: QWenVL input format
        qwen_inputs = self._build_vlm_inputs(batch_images, instructions)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            if self._world_model_mode and hasattr(self.qwen_vl_interface, 'forward_all_layers'):
                # World model: extract per-backbone-block features for layerwise XAttn
                qwenvl_outputs = self.qwen_vl_interface.forward_all_layers(
                    qwen_inputs["pixel_values"], instructions
                )
            else:
                qwenvl_outputs = self.qwen_vl_interface(
                    **qwen_inputs,
                    output_attentions=False,
                    output_hidden_states=True,
                    return_dict=True,
                )
            vl_embs_list, base_hidden = self._extract_vl_embs(qwenvl_outputs)
        video_loss = None

    # Step 4: Action Expert Forward and Loss
    with torch.autocast("cuda", dtype=torch.float32):
        actions = torch.tensor(
            np.array(actions), device=base_hidden.device, dtype=base_hidden.dtype
        )  # [B, T_full, action_dim]
        actions_target = actions[:, -(self.future_action_window_size+1):, :]  # (B, chunk_len, action_dim)

        repeated_diffusion_steps = (
            self.config.trainer.get("repeated_diffusion_steps", 4) if self.config and self.config.trainer else 4
        )
        # NOTE(review): hard-coded override discards the config value read
        # just above — confirm this is intentional before cleaning it up.
        repeated_diffusion_steps = 2  # NO repeat for big action FM
        actions_target_repeated = actions_target.repeat(repeated_diffusion_steps, 1, 1)
        vl_embs_list_repeated = [h.repeat(repeated_diffusion_steps, 1, 1) for h in vl_embs_list]

        state_repeated = None
        if state is not None:
            state = torch.tensor(
                np.array(state), device=base_hidden.device, dtype=base_hidden.dtype
            )
            # Truncate state to match model's expected state_dim (e.g. LIBERO sends 8-dim, model expects 7-dim)
            expected_state_dim = getattr(
                getattr(getattr(self.config, 'framework', None), 'action_model', None),
                'state_dim', None
            )
            if expected_state_dim and state.shape[-1] > expected_state_dim:
                state = state[..., :expected_state_dim]
            # Ensure state is 3D (B, T, state_dim) for action model
            if state.ndim == 2:
                state = state.unsqueeze(1)
            state_repeated = state.repeat(repeated_diffusion_steps, 1, 1)

        action_loss = self.action_model(vl_embs_list_repeated, actions_target_repeated, state_repeated)

    result = {"action_loss": action_loss}
    if video_loss is not None:
        result["video_loss"] = video_loss
        result["total_loss"] = action_loss + self._video_loss_weight * video_loss
    return result
predict_action
predict_action(examples: List[dict] = None, batch_images: List = None, instructions: List[str] = None, states=None, **kwargs) -> np.ndarray

Inference: single forward pass to regress future actions via flow-matching sampling through the layerwise DiT.

Supports two input formats
  • examples: List[dict] with keys "image", "lang", "state" (legacy)
  • batch_images + instructions: direct arguments

Returns:

Name Type Description
dict ndarray

normalized_actions (np.ndarray): Shape [B, T, action_dim].

Source code in AlphaBrain/model/framework/QwenPI.py
@torch.inference_mode()
def predict_action(
    self,
    examples: List[dict] = None,
    batch_images: List = None,
    instructions: List[str] = None,
    states=None,
    **kwargs,
) -> dict:
    """
    Inference: single forward pass to regress future actions via flow-matching
    sampling through the layerwise DiT.

    Supports two input formats:
      - examples: List[dict] with keys "image", "lang", "state" (legacy)
      - batch_images + instructions: direct arguments

    Returns:
        dict:
            normalized_actions (np.ndarray): Shape [B, T, action_dim].

    Note: the return annotation was corrected to ``dict`` — the method has
    always returned a dict (see the final statement), not a bare ndarray.
    """
    if examples is not None:
        if not isinstance(examples, list):
            examples = [examples]
        from deployment.model_server.tools.image_tools import to_pil_preserve
        batch_images = [to_pil_preserve(example["image"]) for example in examples]
        instructions = [example["lang"] for example in examples]
        state = [example["state"] for example in examples] if (self.use_state and "state" in examples[0]) else None
    else:
        assert batch_images is not None and instructions is not None, \
            "Either examples or both batch_images and instructions must be provided"
        # Websocket clients send raw arrays; convert to PIL for the VLM.
        if isinstance(batch_images[0][0], np.ndarray):
            batch_images = [[Image.fromarray(img) for img in seq] for seq in batch_images]
        state = states if self.use_state else None

    # Match the training observation resolution when it is configured.
    train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
    if train_obs_image_size:
        batch_images = resize_images(batch_images, target_size=train_obs_image_size)

    # Step 1: QWenVL input format
    qwen_inputs = self._build_vlm_inputs(batch_images, instructions)
    with torch.autocast("cuda", dtype=torch.bfloat16):
        if self._world_model_mode and hasattr(self.qwen_vl_interface, 'forward_all_layers'):
            # World model: extract per-backbone-block features for layerwise XAttn
            qwenvl_outputs = self.qwen_vl_interface.forward_all_layers(
                qwen_inputs["pixel_values"], instructions
            )
        else:
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )
        vl_embs_list, base_hidden = self._extract_vl_embs(qwenvl_outputs)

    state_tensor = None
    if state is not None:
        state_tensor = torch.from_numpy(np.array(state)).to(base_hidden.device, dtype=base_hidden.dtype)
        # Truncate state to match model expected state_dim
        expected_state_dim = getattr(
            getattr(getattr(self.config, "framework", None), "action_model", None),
            "state_dim", None,
        )
        if expected_state_dim and state_tensor.shape[-1] > expected_state_dim:
            state_tensor = state_tensor[..., :expected_state_dim]
        # Action model expects 3D state: (B, T, state_dim).
        if state_tensor.ndim == 2:
            state_tensor = state_tensor.unsqueeze(1)

    # Step 4: Action Expert Forward
    with torch.autocast("cuda", dtype=torch.float32):
        pred_actions = self.action_model.predict_action(vl_embs_list, state_tensor)

    normalized_actions = pred_actions.detach().cpu().numpy()
    return {"normalized_actions": normalized_actions}

QwenGR00T

Qwen-GR00T Framework: a lightweight implementation that combines Qwen-VL with a flow-matching head to directly predict continuous actions. The flow-matching head is adapted from GR00T N1.5.

Qwen_GR00T

Qwen_GR00T(config: Optional[dict] = None, **kwargs)

Bases: BaseFramework

Multimodal vision-language-action model.

Components
  • Qwen2.5 VL interface for fused language/vision token embeddings
  • Layer-wise QFormer for multi-layer feature aggregation
  • DINO encoder for dense multi-view spatial tokens
  • DiT diffusion head for future action sequence modeling

Focus: Predict future continuous actions conditioned on images + instruction.

Construct all submodules and cache key configuration values.

Parameters:

Name Type Description Default
config Optional[dict]

Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections.

None
**kwargs

Reserved for future overrides (unused).

{}
Source code in AlphaBrain/model/framework/QwenGR00T.py
def __init__(
    self,
    config: Optional[dict] = None,
    **kwargs,
) -> None:
    """
    Construct all submodules and cache key configuration values.

    Args:
        config: Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections.
        **kwargs: Reserved for future overrides (unused).
    """
    super().__init__()
    self.config = config
    self.qwen_vl_interface = get_vlm_model(config=self.config)
    # Align the DiT cross-attention dim with the encoder output: prefer the
    # post-fusion world_model hidden size when configured, else the VLM's
    # hidden size. (TODO: consider moving this wiring into the config.)
    wm_cfg = getattr(self.config.framework, "world_model", None)
    if wm_cfg is not None and getattr(wm_cfg, "hidden_size", None) is not None:
        self.config.framework.action_model.diffusion_model_cfg.cross_attention_dim = wm_cfg.hidden_size
    else:
        self.config.framework.action_model.diffusion_model_cfg.cross_attention_dim = self.qwen_vl_interface.model.config.hidden_size

    self.action_model: FlowmatchingActionHead = get_action_model(config=self.config)  # annotation aids later references

    # Action chunk layout: past window + current step + future window.
    self.future_action_window_size = config.framework.action_model.future_action_window_size
    self.past_action_window_size = config.framework.action_model.past_action_window_size
    self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size

    # Determine whether state was used during training (controls state_encoder usage)
    self.use_state = getattr(
        getattr(getattr(config, 'datasets', None), 'vla_data', None),
        'include_state', False
    )
    # Normalize stringly-typed config values ("False", "false", "") to bool.
    if self.use_state in ["False", False, None, "false", ""]:
        self.use_state = False
    else:
        self.use_state = True
    logger.info(f"[QwenGR00T] use_state={self.use_state} (from config.datasets.vla_data.include_state)")

    # Freeze state_encoder if state is not used (prevents DeepSpeed gradient deadlock)
    if not self.use_state and hasattr(self, 'action_model'):
        if hasattr(self.action_model, 'state_encoder') and self.action_model.state_encoder is not None:
            self.action_model.state_encoder.requires_grad_(False)
            logger.info("[QwenGR00T] Froze state_encoder (include_state=false)")

    # Video loss weight (hyperparameter, not a toggle).
    # WM backbones auto-enable video loss via has_visual_encoder check in forward().
    self._video_loss_weight = float(
        getattr(
            getattr(getattr(config, 'framework', None), 'world_model', None),
            'video_loss_weight', 1.0,
        )
    ) if (hasattr(config, 'framework') and hasattr(config.framework, 'world_model')) else 1.0
    logger.info("[QwenGR00T] video_loss_weight=%.3f", self._video_loss_weight)
forward
forward(examples: List[dict] = None, **kwargs) -> Tuple

Run a full training forward pass, with video loss when next_image is available.

When next_image is available, performs a SINGLE DiT forward that simultaneously yields action visual tokens and the next-frame video prediction loss. Both share the same backward graph so the DiT backbone receives gradients from both losses without a redundant forward pass.

During inference (no next_image): the standard encode path is used and no video loss is computed.

Source code in AlphaBrain/model/framework/QwenGR00T.py
def forward(
    self,
    examples: List[dict] = None,
    **kwargs,
) -> dict:
    """Run a full training forward pass, with video loss when next_image is available.

    When next_image is available, performs
    a SINGLE DiT forward that simultaneously yields action visual tokens and the
    next-frame video prediction loss.  Both share the same backward graph so the
    DiT backbone receives gradients from both losses without a redundant forward pass.

    During inference (no next_image): the standard encode path is used and no video
    loss is computed.

    Returns:
        dict with action_loss, plus video_loss/total_loss when the video path ran.
    """
    batch_images = [example["image"] for example in examples]  # [B, [PIL]]
    instructions = [example["lang"] for example in examples]   # [B, str]
    actions = [example["action"] for example in examples]      # [B, T, action_dim]

    state = [example["state"] for example in examples] if (self.use_state and "state" in examples[0]) else None

    video_loss = None

    # ===================================================================
    # V2 video loss path (training with next_image)
    # Single DiT pass → layer 18 features for action + final output for video
    # ===================================================================
    # BUGFIX: the second clause previously duplicated the key-presence check
    # ("next_image" in examples[0] twice); it must verify the value is not
    # None, mirroring Qwen_PI.forward.
    has_next_images = (
        len(examples) > 0
        and "next_image" in examples[0]
        and examples[0]["next_image"] is not None
    )
    has_visual_encoder = hasattr(self.qwen_vl_interface, "visual_encoder")

    if has_next_images and has_visual_encoder:
        wm_encoder = self.qwen_vl_interface.visual_encoder
        next_images_raw = [example.get("next_image") for example in examples]
        valid_mask = [img is not None for img in next_images_raw]

        # Always take the V2 path here (even with partially-missing
        # next_image) to avoid NCCL deadlock from divergent branches.
        curr_images_flat = [
            imgs[0] if isinstance(imgs, (list, tuple)) else imgs
            for imgs in batch_images
        ]
        dummy_next = [
            next_images_raw[i] if valid_mask[i] else curr_images_flat[i]
            for i in range(len(examples))
        ]

        with torch.autocast("cuda", dtype=torch.bfloat16):
            curr_pv = wm_encoder.preprocess(curr_images_flat)
            next_pv = wm_encoder.preprocess(dummy_next)
            qwenvl_outputs, video_loss_raw = self.qwen_vl_interface.forward_with_video_loss(
                curr_pv, instructions, next_pv
            )
            last_hidden = qwenvl_outputs.hidden_states[-1]  # [B, L, H]

        valid_count = sum(valid_mask)
        if valid_count == 0:
            video_loss = video_loss_raw * 0.0  # no valid next_image, zero out video loss
        elif valid_count < len(valid_mask):
            # Rescale so padded samples do not dilute the video loss.
            scale = len(valid_mask) / valid_count
            video_loss = video_loss_raw * scale
        else:
            video_loss = video_loss_raw

    if not (has_next_images and has_visual_encoder) or video_loss is None:
        # ===============================================================
        # Standard encode path (inference or training without video loss)
        # ===============================================================
        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(
            images=batch_images, instructions=instructions
        )
        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )
            last_hidden = qwenvl_outputs.hidden_states[-1]  # [B, L, H]
        video_loss = None

    # ===================================================================
    # Action Expert Forward and Loss
    # ===================================================================
    with torch.autocast("cuda", dtype=torch.float32):
        actions = torch.tensor(
            np.array(actions), device=last_hidden.device, dtype=last_hidden.dtype
        )  # [B, T_full, action_dim]
        actions_target = actions[:, -(self.future_action_window_size + 1):, :]

        repeated_diffusion_steps = (
            self.config.trainer.get("repeated_diffusion_steps", 4)
            if self.config and self.config.trainer else 4
        )
        actions_target_repeated = actions_target.repeat(repeated_diffusion_steps, 1, 1)
        last_hidden_repeated = last_hidden.repeat(repeated_diffusion_steps, 1, 1)

        state_repeated = None
        if state is not None:
            state = torch.tensor(
                np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
            )
            # Truncate state to the model's expected state_dim if wider.
            expected_state_dim = getattr(
                getattr(getattr(self.config, "framework", None), "action_model", None),
                "state_dim", None,
            )
            if expected_state_dim and state.shape[-1] > expected_state_dim:
                state = state[..., :expected_state_dim]
            if state.ndim == 2:
                state = state.unsqueeze(1)
            state_repeated = state.repeat(repeated_diffusion_steps, 1, 1)

        action_loss = self.action_model(
            last_hidden_repeated, actions_target_repeated, state_repeated
        )

    # ===================================================================
    # Combine losses and build output dict
    # ===================================================================
    result = {"action_loss": action_loss}

    if video_loss is not None:
        result["video_loss"] = video_loss
        result["total_loss"] = action_loss + self._video_loss_weight * video_loss

    return result
predict_action
predict_action(examples: List[dict] = None, batch_images: List = None, instructions: List[str] = None, states=None, return_predicted_frame: bool = False, **kwargs) -> np.ndarray
Steps
  1. Resize images to training resolution (if specified)
  2. Encode with QwenVL (hidden states retained)
  3. Return normalized action trajectory
Supports two input formats
  • examples: List[dict] with keys "image", "lang", "state" (legacy format)
  • batch_images + instructions: direct arguments (consistent with NeuroVLA/QwenOFT)

Returns:

Name Type Description
dict ndarray

normalized_actions (np.ndarray): Shape [B, T, action_dim], diffusion-sampled normalized actions.

Source code in AlphaBrain/model/framework/QwenGR00T.py
@torch.inference_mode()
def predict_action(
    self,
    examples: List[dict] = None,
    batch_images: List = None,
    instructions: List[str] = None,
    states=None,
    return_predicted_frame: bool = False,
    **kwargs,
) -> dict:
    """
    Steps:
      1. Resize images to training resolution (if specified)
      2. Encode with QwenVL (hidden states retained)
      3. Return normalized action trajectory

    Supports two input formats:
      - examples: List[dict] with keys "image", "lang", "state" (legacy format)
      - batch_images + instructions: direct arguments (consistent with NeuroVLA/QwenOFT)

    Returns:
        dict:
            normalized_actions (np.ndarray): Shape [B, T, action_dim], diffusion-sampled normalized actions.
            predicted_frame (np.ndarray, optional): only when return_predicted_frame
                is set and the world-model encoder supports frame denoising.

    Note: the return annotation was corrected to ``dict`` — the method has
    always returned a dict, not a bare ndarray.
    """
    if examples is not None:
        # Legacy examples format
        if not isinstance(examples, list):
            examples = [examples]
        batch_images = [to_pil_preserve(example["image"]) for example in examples]
        instructions = [example["lang"] for example in examples]  # [B, str]
        state = [example["state"] for example in examples] if (self.use_state and "state" in examples[0]) else None
    else:
        # Direct batch_images/instructions format (from websocket client)
        assert batch_images is not None and instructions is not None, \
            "Either 'examples' or both 'batch_images' and 'instructions' must be provided"
        if isinstance(batch_images[0][0], np.ndarray):
            batch_images = [[Image.fromarray(img) for img in seq] for seq in batch_images]
        state = states if self.use_state else None

    # Match the training observation resolution when it is configured.
    train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
    if train_obs_image_size:
        batch_images = resize_images(batch_images, target_size=train_obs_image_size)

    # Step 1: QWenVL input format
    qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(images=batch_images, instructions=instructions)
    with torch.autocast("cuda", dtype=torch.bfloat16):
        qwenvl_outputs = self.qwen_vl_interface(
            **qwen_inputs,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )

        # last_hidden_state: [B, seq_len, H]
        last_hidden = qwenvl_outputs.hidden_states[-1]   # [B, L, H]

    state = torch.from_numpy(np.array(state)).to(last_hidden.device, dtype=last_hidden.dtype) if state is not None else None

    # Truncate state to match model's expected state_dim (e.g. LIBERO sends 8-dim, model expects 7-dim)
    if state is not None:
        expected_state_dim = getattr(
            getattr(getattr(self.config, 'framework', None), 'action_model', None),
            'state_dim', None
        )
        if expected_state_dim and state.shape[-1] > expected_state_dim:
            state = state[..., :expected_state_dim]
        # Ensure state is 3D (B, T, state_dim) for action model
        if state.ndim == 2:
            state = state.unsqueeze(1)

    # Step 4: Action Expert Forward
    with torch.autocast("cuda", dtype=torch.float32):
        pred_actions = self.action_model.predict_action(last_hidden, state)  # (B, chunk_len, action_dim)

    normalized_actions = pred_actions.detach().cpu().numpy()
    result = {"normalized_actions": normalized_actions}
    # Optional future-frame prediction is best-effort: failures are logged,
    # never propagated, so action prediction is unaffected.
    if return_predicted_frame and hasattr(self.qwen_vl_interface, "visual_encoder") and hasattr(self.qwen_vl_interface.visual_encoder, "denoise_future_frame"):
        try:
            predicted_frame = self.predict_future_frame(
                batch_images=batch_images, instructions=instructions,
                num_steps=5, sigma_min=0.002, sigma_max=80.0,
            )
            result["predicted_frame"] = predicted_frame
        except Exception as e:
            logger.warning("predict_future_frame failed: %s", e)
    return result
predict_future_frame
predict_future_frame(batch_images: List = None, instructions: List[str] = None, num_steps: int = 5, sigma_min: float = 4.0, sigma_max: float = 80.0) -> np.ndarray

Predict future frame using DiT denoising.

Returns:

Name Type Description
future_frames ndarray

np.ndarray [B, H, W, 3] uint8 predicted future frames

Source code in AlphaBrain/model/framework/QwenGR00T.py
@torch.inference_mode()
def predict_future_frame(
    self,
    batch_images: List = None,
    instructions: List[str] = None,
    num_steps: int = 5,
    sigma_min: float = 4.0,
    sigma_max: float = 80.0,
) -> np.ndarray:
    """Predict future frame using DiT denoising.

    Returns:
        future_frames: np.ndarray [B, H, W, 3] uint8 predicted future frames
    """
    if not hasattr(self.qwen_vl_interface, 'visual_encoder'):
        raise ValueError("predict_future_frame requires world model visual encoder")

    wm_encoder = self.qwen_vl_interface.visual_encoder

    # Preprocess images
    curr_images = [imgs[0] if isinstance(imgs, (list, tuple)) else imgs for imgs in batch_images]
    with torch.autocast("cuda", dtype=torch.bfloat16):
        curr_pv = wm_encoder.preprocess(curr_images)

        # Get text embeddings
        text_embeds = wm_encoder.encode_text(instructions, curr_pv.device) if hasattr(wm_encoder, 'encode_text') else None

        # Encode current frame to latent
        latent_t = wm_encoder.encode_to_latent(curr_pv)

        # Denoise future frame
        future_latent = wm_encoder.denoise_future_frame(
            latent_t, text_embeds, num_steps=num_steps,
            sigma_min=sigma_min, sigma_max=sigma_max,
        )

        # Decode to pixels
        future_video = wm_encoder.decode_latent(future_latent)

    # Convert to uint8 numpy
    future_frames = ((future_video[:, :, 0] + 1) * 127.5).clamp(0, 255)
    future_frames = future_frames.permute(0, 2, 3, 1).to(torch.uint8).cpu().numpy()

    return future_frames

World Model VLA

WorldModelVLA

WorldModelVLA Framework (Phase 2b-A)

Clean-room rename/clone of Qwen_GR00T for world-model-backbone VLA: - Uses WorldModelVLMInterface (not a Qwen VLM) as visual encoder - Reads config.framework.world_model. (not config.framework.qwenvl.) - Attribute renamed: self.qwen_vl_interface -> self.world_model_encoder - Framework-local wrapper: prepare_inputs (still calls interface.build_vlm_inputs under the hood; interface method rename deferred to Phase 3+)

QwenGR00T.py remains the canonical path for Qwen VLMs and is NOT modified.

WorldModelVLA

WorldModelVLA(config: Optional[dict] = None, **kwargs)

Bases: BaseFramework

World-Model-backbone vision-language-action model.

Components
  • WorldModelVLMInterface: Cosmos-Predict2 / V-JEPA2 / WAN2 visual encoder with lightweight text encoder and cross-attention fusion
  • FlowmatchingActionHead (GR00T-N1.5) for future action sequence modeling

Focus: Predict future continuous actions conditioned on images + instruction, optionally with next-frame video loss for joint WM + action training.

Construct all submodules and cache key configuration values.

Parameters:

Name Type Description Default
config Optional[dict]

Hierarchical configuration (OmegaConf/dict) containing framework.world_model. (encoder) and framework.action_model. (FlowMatching head) sections.

None
**kwargs

Reserved for future overrides (unused).

{}
Source code in AlphaBrain/model/framework/WorldModelVLA.py
def __init__(
    self,
    config: Optional[dict] = None,
    **kwargs,
) -> None:
    """
    Construct all submodules and cache key configuration values.

    Args:
        config: Hierarchical configuration (OmegaConf/dict) containing
            framework.world_model.* (encoder) and framework.action_model.*
            (FlowMatching head) sections.
        **kwargs: Reserved for future overrides (unused).
    """
    super().__init__()
    self.config = config
    self.world_model_encoder = get_world_model_encoder(config=self.config)

    # Align dims: use world_model hidden_size (post-fusion) if available,
    # else fall back to encoder's internal hidden size.
    wm_cfg = getattr(self.config.framework, "world_model", None)
    if wm_cfg is not None and getattr(wm_cfg, "hidden_size", None) is not None:
        cross_attention_dim = wm_cfg.hidden_size
    else:
        cross_attention_dim = self.world_model_encoder.model.config.hidden_size
    self.config.framework.action_model.diffusion_model_cfg.cross_attention_dim = cross_attention_dim

    self.action_model: FlowmatchingActionHead = get_action_model(config=self.config)

    # Action chunk layout: past actions + current step + future actions.
    self.future_action_window_size = config.framework.action_model.future_action_window_size
    self.past_action_window_size = config.framework.action_model.past_action_window_size
    self.chunk_len = self.past_action_window_size + 1 + self.future_action_window_size

    # Determine whether state was used during training (controls state_encoder usage).
    # Config values may arrive as strings ("false"), booleans, or be absent entirely,
    # so normalize all falsy spellings to a real bool.
    raw_include_state = getattr(
        getattr(getattr(config, 'datasets', None), 'vla_data', None),
        'include_state', False
    )
    self.use_state = raw_include_state not in ("False", False, None, "false", "")
    logger.info(f"[WorldModelVLA] use_state={self.use_state} (from config.datasets.vla_data.include_state)")

    # Freeze state_encoder if state is not used (prevents DeepSpeed gradient deadlock).
    # Note: self.action_model was assigned above, so no hasattr(self, ...) guard is needed.
    if not self.use_state:
        if getattr(self.action_model, 'state_encoder', None) is not None:
            self.action_model.state_encoder.requires_grad_(False)
            logger.info("[WorldModelVLA] Froze state_encoder (include_state=false)")

    # Video loss weight (hyperparameter, not a toggle).
    # WM backbones auto-enable video loss via has_visual_encoder check in forward().
    if hasattr(config, 'framework') and hasattr(config.framework, 'world_model'):
        self._video_loss_weight = float(
            getattr(config.framework.world_model, 'video_loss_weight', 1.0)
        )
    else:
        self._video_loss_weight = 1.0
    logger.info("[WorldModelVLA] video_loss_weight=%.3f", self._video_loss_weight)
prepare_inputs
prepare_inputs(images, instructions)

Framework-local wrapper around the encoder's build_vlm_inputs.

Interface method rename is deferred to Phase 3+; for now the underlying call still hits self.world_model_encoder.build_vlm_inputs(...).

Source code in AlphaBrain/model/framework/WorldModelVLA.py
def prepare_inputs(self, images, instructions):
    """Framework-local wrapper around the encoder's build_vlm_inputs.

    Interface method rename is deferred to Phase 3+; for now the underlying
    call still hits self.world_model_encoder.build_vlm_inputs(...).
    """
    encoder = self.world_model_encoder
    return encoder.build_vlm_inputs(images=images, instructions=instructions)
forward
forward(examples: List[dict] = None, **kwargs) -> Tuple

Run a full training forward pass, with video loss when next_image is available.

When next_image is available, performs a SINGLE DiT forward that simultaneously yields action visual tokens and the next-frame video prediction loss. Both share the same backward graph so the DiT backbone receives gradients from both losses without a redundant forward pass.

During inference (no next_image): the standard encode path is used and no video loss is computed.

Source code in AlphaBrain/model/framework/WorldModelVLA.py
def forward(
    self,
    examples: List[dict] = None,
    **kwargs,
) -> Tuple:
    """Run a full training forward pass, with video loss when next_image is available.

    When next_image is available, performs a SINGLE DiT forward that
    simultaneously yields action visual tokens and the next-frame video
    prediction loss.  Both share the same backward graph so the DiT backbone
    receives gradients from both losses without a redundant forward pass.

    During inference (no next_image): the standard encode path is used and no
    video loss is computed.

    Args:
        examples: Batch of dicts with keys "image" (list of PIL images),
            "lang" (str), "action" ([T, action_dim]) and optionally
            "state" / "next_image".
        **kwargs: Unused; accepted for interface compatibility.

    Returns:
        dict with "action_loss" and, when the video loss path ran,
        "video_loss" and "total_loss" as well.
    """
    batch_images = [example["image"] for example in examples]  # [B, [PIL]]
    instructions = [example["lang"] for example in examples]   # [B, str]
    actions = [example["action"] for example in examples]      # [B, T, action_dim]

    state = [example["state"] for example in examples] if (self.use_state and "state" in examples[0]) else None

    video_loss = None

    # ===================================================================
    # V2 video loss path (training with next_image)
    # Single DiT pass -> layer 18 features for action + final output for video
    # ===================================================================
    # NOTE: the original guard tested `"next_image" in examples[0]` twice;
    # the duplicate was dead and has been removed.
    has_next_images = len(examples) > 0 and "next_image" in examples[0]
    has_visual_encoder = hasattr(self.world_model_encoder, "visual_encoder")

    if has_next_images and has_visual_encoder:
        wm_encoder = self.world_model_encoder.visual_encoder
        next_images_raw = [example.get("next_image") for example in examples]
        valid_mask = [img is not None for img in next_images_raw]

        # Always take the V2 path, even when some next_images are missing, so
        # every rank runs the same graph (avoids NCCL deadlock). Missing
        # next_images are substituted with the current frame and their loss
        # contribution is rescaled/zeroed below.
        curr_images_flat = [
            imgs[0] if isinstance(imgs, (list, tuple)) else imgs
            for imgs in batch_images
        ]
        dummy_next = [
            next_images_raw[i] if valid_mask[i] else curr_images_flat[i]
            for i in range(len(examples))
        ]

        with torch.autocast("cuda", dtype=torch.bfloat16):
            curr_pv = wm_encoder.preprocess(curr_images_flat)
            next_pv = wm_encoder.preprocess(dummy_next)
            wm_outputs, video_loss_raw = self.world_model_encoder.forward_with_video_loss(
                curr_pv, instructions, next_pv
            )
            last_hidden = wm_outputs.hidden_states[-1]  # [B, L, H]

        valid_count = sum(valid_mask)
        if valid_count == 0:
            video_loss = video_loss_raw * 0.0  # no valid next_image, zero out video loss
        elif valid_count < len(valid_mask):
            # Unbias the mean: dummy frames diluted the raw loss.
            scale = len(valid_mask) / valid_count
            video_loss = video_loss_raw * scale
        else:
            video_loss = video_loss_raw

    if not (has_next_images and has_visual_encoder) or video_loss is None:
        # ===============================================================
        # Standard encode path (inference or training without video loss)
        # ===============================================================
        wm_inputs = self.prepare_inputs(
            images=batch_images, instructions=instructions
        )
        with torch.autocast("cuda", dtype=torch.bfloat16):
            wm_outputs = self.world_model_encoder(
                **wm_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )
            last_hidden = wm_outputs.hidden_states[-1]  # [B, L, H]
        video_loss = None

    # ===================================================================
    # Action Expert Forward and Loss
    # ===================================================================
    with torch.autocast("cuda", dtype=torch.float32):
        actions = torch.tensor(
            np.array(actions), device=last_hidden.device, dtype=last_hidden.dtype
        )  # [B, T_full, action_dim]
        # Supervise only the current + future slice of the action chunk.
        actions_target = actions[:, -(self.future_action_window_size + 1):, :]

        repeated_diffusion_steps = (
            self.config.trainer.get("repeated_diffusion_steps", 4)
            if self.config and self.config.trainer else 4
        )
        # Repeat the batch so each sample is trained at several diffusion steps.
        actions_target_repeated = actions_target.repeat(repeated_diffusion_steps, 1, 1)
        last_hidden_repeated = last_hidden.repeat(repeated_diffusion_steps, 1, 1)

        state_repeated = None
        if state is not None:
            state = torch.tensor(
                np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
            )
            expected_state_dim = getattr(
                getattr(getattr(self.config, "framework", None), "action_model", None),
                "state_dim", None,
            )
            # Truncate over-long state vectors to the model's expected dim.
            if expected_state_dim and state.shape[-1] > expected_state_dim:
                state = state[..., :expected_state_dim]
            if state.ndim == 2:
                state = state.unsqueeze(1)  # ensure (B, T, state_dim)
            state_repeated = state.repeat(repeated_diffusion_steps, 1, 1)

        action_loss = self.action_model(
            last_hidden_repeated, actions_target_repeated, state_repeated
        )

    # ===================================================================
    # Combine losses and build output dict
    # ===================================================================
    result = {"action_loss": action_loss}

    if video_loss is not None:
        result["video_loss"] = video_loss
        result["total_loss"] = action_loss + self._video_loss_weight * video_loss

    return result
predict_action
predict_action(examples: List[dict] = None, batch_images: List = None, instructions: List[str] = None, states=None, return_predicted_frame: bool = False, **kwargs) -> np.ndarray
Steps
  1. Resize images to training resolution (if specified)
  2. Encode with world model (hidden states retained)
  3. Run FlowMatching action head
  4. Return normalized action trajectory
Supports two input formats
  • examples: List[dict] with keys "image", "lang", "state" (legacy format)
  • batch_images + instructions: direct arguments (consistent with NeuroVLA/QwenOFT)

Returns:

Type: dict

normalized_actions (np.ndarray): Shape [B, T, action_dim], diffusion-sampled normalized actions.

Source code in AlphaBrain/model/framework/WorldModelVLA.py
@torch.inference_mode()
def predict_action(
    self,
    examples: List[dict] = None,
    batch_images: List = None,
    instructions: List[str] = None,
    states=None,
    return_predicted_frame: bool = False,
    **kwargs,
) -> np.ndarray:
    """
    Steps:
      1. Resize images to training resolution (if specified)
      2. Encode with world model (hidden states retained)
      3. Run FlowMatching action head
      4. Return normalized action trajectory

    Supports two input formats:
      - examples: List[dict] with keys "image", "lang", "state" (legacy format)
      - batch_images + instructions: direct arguments (consistent with NeuroVLA/QwenOFT)

    Returns:
        dict:
            normalized_actions (np.ndarray): Shape [B, T, action_dim], diffusion-sampled normalized actions.
    """
    if examples is not None:
        # Legacy examples format
        if type(examples) is not list:
            examples = [examples]
        batch_images = [to_pil_preserve(example["image"]) for example in examples]
        instructions = [example["lang"] for example in examples]  # [B, str]
        state = [example["state"] for example in examples] if (self.use_state and "state" in examples[0]) else None
    else:
        # Direct batch_images/instructions format (from websocket client)
        assert batch_images is not None and instructions is not None, \
            "Either 'examples' or both 'batch_images' and 'instructions' must be provided"
        if isinstance(batch_images[0][0], np.ndarray):
            batch_images = [[Image.fromarray(img) for img in seq] for seq in batch_images]
        state = states if self.use_state else None

    train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
    if train_obs_image_size:
        batch_images = resize_images(batch_images, target_size=train_obs_image_size)

    # Step 1: World-model input format
    wm_inputs = self.prepare_inputs(images=batch_images, instructions=instructions)
    with torch.autocast("cuda", dtype=torch.bfloat16):
        wm_outputs = self.world_model_encoder(
            **wm_inputs,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )

        # last_hidden_state: [B, seq_len, H]
        last_hidden = wm_outputs.hidden_states[-1]   # [B, L, H]

    state = torch.from_numpy(np.array(state)).to(last_hidden.device, dtype=last_hidden.dtype) if state is not None else None

    # Truncate state to match model's expected state_dim (e.g. LIBERO sends 8-dim, model expects 7-dim)
    if state is not None:
        expected_state_dim = getattr(
            getattr(getattr(self.config, 'framework', None), 'action_model', None),
            'state_dim', None
        )
        if expected_state_dim and state.shape[-1] > expected_state_dim:
            state = state[..., :expected_state_dim]
        # Ensure state is 3D (B, T, state_dim) for action model
        if state.ndim == 2:
            state = state.unsqueeze(1)

    # Step 4: Action Expert Forward
    with torch.autocast("cuda", dtype=torch.float32):
        pred_actions = self.action_model.predict_action(last_hidden, state)  # (B, chunk_len, action_dim)

    normalized_actions = pred_actions.detach().cpu().numpy()
    result = {"normalized_actions": normalized_actions}
    if return_predicted_frame and hasattr(self.world_model_encoder, "visual_encoder") and hasattr(self.world_model_encoder.visual_encoder, "denoise_future_frame"):
        try:
            predicted_frame = self.predict_future_frame(
                batch_images=batch_images, instructions=instructions,
                num_steps=5, sigma_min=0.002, sigma_max=80.0,
            )
            result["predicted_frame"] = predicted_frame
        except Exception as e:
            import logging
            logging.warning("predict_future_frame failed: %s", e)
    return result
predict_future_frame
predict_future_frame(batch_images: List = None, instructions: List[str] = None, num_steps: int = 5, sigma_min: float = 4.0, sigma_max: float = 80.0) -> np.ndarray

Predict future frame using DiT denoising.

Returns:

future_frames (np.ndarray): [B, H, W, 3] uint8 predicted future frames

Source code in AlphaBrain/model/framework/WorldModelVLA.py
@torch.inference_mode()
def predict_future_frame(
    self,
    batch_images: List = None,
    instructions: List[str] = None,
    num_steps: int = 5,
    sigma_min: float = 4.0,
    sigma_max: float = 80.0,
) -> np.ndarray:
    """Predict future frame using DiT denoising.

    Pipeline: preprocess current frames -> (optional) text embeddings ->
    encode to latent -> iterative denoising toward the future latent ->
    decode back to pixel space.

    Args:
        batch_images: Batch of images; each entry may be a single image or a
            list/tuple of views (only the first view is used).
        instructions: Per-sample language instructions, consumed only when the
            encoder exposes ``encode_text``.
        num_steps: Number of denoising steps passed to the encoder.
        sigma_min: Minimum noise level for the denoising schedule.
        sigma_max: Maximum noise level for the denoising schedule.

    Returns:
        future_frames: np.ndarray [B, H, W, 3] uint8 predicted future frames

    Raises:
        ValueError: If the world model encoder has no ``visual_encoder``.
    """
    if not hasattr(self.world_model_encoder, 'visual_encoder'):
        raise ValueError("predict_future_frame requires world model visual encoder")

    wm_encoder = self.world_model_encoder.visual_encoder

    # Preprocess images: take the first view when multiple views are provided.
    curr_images = [imgs[0] if isinstance(imgs, (list, tuple)) else imgs for imgs in batch_images]
    with torch.autocast("cuda", dtype=torch.bfloat16):
        curr_pv = wm_encoder.preprocess(curr_images)

        # Get text embeddings; encoders without encode_text run unconditioned.
        text_embeds = wm_encoder.encode_text(instructions, curr_pv.device) if hasattr(wm_encoder, 'encode_text') else None

        # Encode current frame to latent
        latent_t = wm_encoder.encode_to_latent(curr_pv)

        # Denoise future frame
        future_latent = wm_encoder.denoise_future_frame(
            latent_t, text_embeds, num_steps=num_steps,
            sigma_min=sigma_min, sigma_max=sigma_max,
        )

        # Decode to pixels
        future_video = wm_encoder.decode_latent(future_latent)

    # Convert to uint8 numpy. The decoder output appears to be in [-1, 1] with
    # a frame axis at dim 2 ([:, :, 0] selects the first frame) — TODO confirm
    # against decode_latent's actual layout.
    future_frames = ((future_video[:, :, 0] + 1) * 127.5).clamp(0, 255)
    # [B, C, H, W] -> [B, H, W, 3] uint8 on CPU.
    future_frames = future_frames.permute(0, 2, 3, 1).to(torch.uint8).cpu().numpy()

    return future_frames