Skip to content

Dataloader

Source path: AlphaBrain/dataloader/

Aggregates the data pipelines used by VLA training: LeRobot / PaliGemma / VLM / Cosmos / GR00T / Qwen-VL LLaVA-JSON.


Top-level factory

dataloader

save_dataset_statistics

save_dataset_statistics(dataset_statistics, run_dir)

Saves a dataset_statistics.json file.

Source code in AlphaBrain/dataloader/__init__.py
def save_dataset_statistics(dataset_statistics, run_dir):
    """Saves a `dataset_statistics.json` file under `run_dir`.

    Converts numpy containers inside each per-dataset stats dict to plain
    Python types *in place* (callers may observe the mutation), then dumps
    the whole mapping as JSON.

    :param dataset_statistics: Mapping of dataset name -> stats dict with an
        "action" sub-dict and optionally "proprio", "num_trajectories" and
        "num_transitions" entries.
    :param run_dir: Path-like directory the JSON file is written into.
    """

    def _to_builtin(obj):
        # Safety net for numpy values not covered by the explicit
        # conversions below (e.g. a stray np.float64 nested in the stats);
        # previously json.dump would raise TypeError on them.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    out_path = run_dir / "dataset_statistics.json"
    # Normalize known numpy fields before serialization (kept in place for
    # backward compatibility with callers that reuse the mutated dict).
    for _, stats in dataset_statistics.items():
        for k in stats["action"].keys():
            if isinstance(stats["action"][k], np.ndarray):
                stats["action"][k] = stats["action"][k].tolist()
        if "proprio" in stats:
            for k in stats["proprio"].keys():
                if isinstance(stats["proprio"][k], np.ndarray):
                    stats["proprio"][k] = stats["proprio"][k].tolist()
        if "num_trajectories" in stats:
            if isinstance(stats["num_trajectories"], np.ndarray):
                stats["num_trajectories"] = stats["num_trajectories"].item()
        if "num_transitions" in stats:
            if isinstance(stats["num_transitions"], np.ndarray):
                stats["num_transitions"] = stats["num_transitions"].item()
    # Keep the file open only for the actual write.
    with open(out_path, "w") as f_json:
        json.dump(dataset_statistics, f_json, indent=2, default=_to_builtin)
    logger.info(f"Saved dataset statistics file at path {out_path}")

LeRobot datasets

lerobot_datasets

make_LeRobotSingleDataset

make_LeRobotSingleDataset(data_root_dir: Path | str, data_name: str, robot_type: str, delete_pause_frame: bool = False, data_cfg: dict | None = None) -> LeRobotSingleDataset

Make a LeRobotSingleDataset object.

:param data_root_dir: The root directory of the dataset. :param data_name: The name of the dataset. :param robot_type: The robot type config to use. :param delete_pause_frame: Whether to drop pause frames from episodes. :param data_cfg: Optional dataset config dict (may carry e.g. video_backend). :return: A LeRobotSingleDataset object.

Source code in AlphaBrain/dataloader/lerobot_datasets.py
def make_LeRobotSingleDataset(
    data_root_dir: Path | str,
    data_name: str,
    robot_type: str,
    delete_pause_frame: bool = False,
    data_cfg: dict | None = None,
) -> LeRobotSingleDataset:
    """
    Build a LeRobotSingleDataset for a single dataset directory.

    :param data_root_dir: The root directory that holds the datasets.
    :param data_name: Dataset name relative to the root, or an absolute path.
    :param robot_type: Key into ROBOT_TYPE_CONFIG_MAP selecting the data config.
    :param delete_pause_frame: Forwarded to LeRobotSingleDataset.
    :param data_cfg: Optional dataset config dict; may carry `video_backend`.
    :return: A LeRobotSingleDataset object.
    """
    robot_config = ROBOT_TYPE_CONFIG_MAP[robot_type]
    modality_config = robot_config.modality_config()
    transforms = robot_config.transform()

    data_root_dir = Path(data_root_dir)
    candidate = Path(data_name)
    dataset_path = candidate if candidate.is_absolute() else data_root_dir / data_name

    if robot_type in ROBOT_TYPE_TO_EMBODIMENT_TAG:
        embodiment_tag = ROBOT_TYPE_TO_EMBODIMENT_TAG[robot_type]
    else:
        print(f"Warning: Robot type {robot_type} not found in ROBOT_TYPE_TO_EMBODIMENT_TAG, using {EmbodimentTag.NEW_EMBODIMENT} as default")
        embodiment_tag = EmbodimentTag.NEW_EMBODIMENT

    # decord is the faster backend; torchvision_av is needed for AV1 video.
    # NOTE(review): the fallback differs depending on whether data_cfg is
    # provided ("decord" vs "torchvision_av") — confirm this is intentional.
    video_backend = data_cfg.get("video_backend", "decord") if data_cfg else "torchvision_av"

    dataset = LeRobotSingleDataset(
        dataset_path=dataset_path,
        modality_configs=modality_config,
        transforms=transforms,
        embodiment_tag=embodiment_tag,
        video_backend=video_backend,
        delete_pause_frame=delete_pause_frame,
        data_cfg=data_cfg,
    )
    # Keep mixture members unique even when many datasets end with the same
    # leaf directory name, e.g. `lerobot`.
    dataset._dataset_name = str(data_name)
    return dataset

get_vla_dataset

get_vla_dataset(data_cfg: dict, mode: str = 'train', balance_dataset_weights: bool = False, balance_trajectory_weights: bool = False, seed: int = 42, **kwargs: dict) -> LeRobotMixtureDataset

Get a LeRobotMixtureDataset object.

Source code in AlphaBrain/dataloader/lerobot_datasets.py
def get_vla_dataset(
    data_cfg: dict,
    mode: str = "train",
    balance_dataset_weights: bool = False,
    balance_trajectory_weights: bool = False,
    seed: int = 42,
    **kwargs: dict,
) -> LeRobotMixtureDataset:
    """
    Build a LeRobotMixtureDataset from the named mixture in `data_cfg`.

    Expands each mixture entry under the data root, drops duplicate
    (name, robot_type) pairs, then instantiates one LeRobotSingleDataset
    per surviving entry.
    """
    data_root_dir = Path(data_cfg.data_root_dir)
    delete_pause_frame = data_cfg.get("delete_pause_frame", False)
    mixture_spec = DATASET_NAMED_MIXTURES[data_cfg.dataset_mix]

    seen_keys = set()
    unique_spec = []
    for d_name, d_weight, robot_type in mixture_spec:
        expanded_entries = _expand_mixture_entry(data_root_dir, d_name, robot_type)
        if not expanded_entries:
            print(f"Warning: No datasets matched `{d_name}` under `{data_root_dir}`")
        for expanded_name, expanded_robot_type in expanded_entries:
            key = (expanded_name, expanded_robot_type)
            if key in seen_keys:
                print(f"Skipping Duplicate Dataset: `{(expanded_name, d_weight, expanded_robot_type)}`")
                continue
            seen_keys.add(key)
            unique_spec.append((expanded_name, d_weight, expanded_robot_type))

    dataset_mixture = [
        (make_LeRobotSingleDataset(data_root_dir, name, robot, delete_pause_frame=delete_pause_frame, data_cfg=data_cfg), weight)
        for name, weight, robot in unique_spec
    ]

    return LeRobotMixtureDataset(
        dataset_mixture,
        mode=mode,
        balance_dataset_weights=balance_dataset_weights,
        balance_trajectory_weights=balance_trajectory_weights,
        seed=seed,
        data_cfg=data_cfg,
        **kwargs,
    )

PaliGemma datasets

paligemma_datasets

PaliGemmaOFT Data Pipeline

Adapts VLAE's existing LeRobot data loading to PaliGemmaOFT format. Reuses the existing lerobot_datasets.py infrastructure, adding Pi0-specific transforms.

Pi0 expects
  • images: dict of {camera_name: [B, H, W, 3] uint8 tensors}
  • image_masks: dict of {camera_name: [B] bool tensors}
  • state: [B, state_dim] float32
  • tokenized_prompt: [B, max_token_len] int32
  • tokenized_prompt_mask: [B, max_token_len] bool
  • actions: [B, action_horizon, action_dim] float32

Pi0DataConfig dataclass

Pi0DataConfig(image_resolution: tuple = (224, 224), max_token_len: int = 200, action_horizon: int = 50, action_dim: int = 7, camera_names: tuple = ('image_0',), include_state: bool = True, state_dim: int = 7)

Configuration for Pi0-specific data processing.

Pi0DataTransform

Pi0DataTransform(config: Pi0DataConfig, tokenizer=None)

Transform VLAE LeRobot data samples into PaliGemmaOFT format.

Input (from LeRobot dataloader): dict with keys: image (List[PIL.Image]), lang (str), action (np.ndarray), state (np.ndarray)

Output (for PaliGemmaOFT.forward()): dict with same keys, but images resized and ready for Pi0 processing

Source code in AlphaBrain/dataloader/paligemma_datasets.py
def __init__(self, config: Pi0DataConfig, tokenizer=None):
    """Store the Pi0 data config and the (optional) PaliGemma/Gemma tokenizer."""
    self.tokenizer = tokenizer  # PaliGemma/Gemma tokenizer; may be None
    self.config = config

Pi0DatasetWrapper

Pi0DatasetWrapper(base_dataset, transform)

Wraps a LeRobot dataset with Pi0-specific transforms.

Source code in AlphaBrain/dataloader/paligemma_datasets.py
def __init__(self, base_dataset, transform):
    """Wrap `base_dataset` so `transform` is applied per sample."""
    self.transform = transform
    self.base_dataset = base_dataset

get_pi0_dataset

get_pi0_dataset(data_cfg, mode='train', **kwargs)

Get dataset for PaliGemmaOFT training.

Reuses VLAE's existing LeRobot data loading, wrapping it with Pi0-specific transforms.

Parameters:

Name Type Description Default
data_cfg

dataset config (same as used by other VLAE frameworks)

required
mode

"train" or "eval"

'train'

Returns:

Type Description

dataset wrapped with Pi0DataTransform

Source code in AlphaBrain/dataloader/paligemma_datasets.py
def get_pi0_dataset(data_cfg, mode="train", **kwargs):
    """
    Get dataset for PaliGemmaOFT training.

    Reuses VLAE's existing LeRobot data loading, wrapping it with Pi0-specific transforms.

    Args:
        data_cfg: dataset config (same as used by other VLAE frameworks)
        mode: "train" or "eval"

    Returns:
        dataset wrapped with Pi0DataTransform
    """
    from AlphaBrain.dataloader.lerobot_datasets import get_vla_dataset
    from AlphaBrain.dataloader.gr00t_lerobot.data_config import ROBOT_TYPE_CONFIG_MAP

    # Read once and reuse (previously fetched twice via getattr).
    action_horizon = getattr(data_cfg, 'action_horizon', 50)

    # NOTE(review): the overrides below mutate the shared ROBOT_TYPE_CONFIG_MAP
    # entry in place, so they leak into any later use of "libero_franka".
    libero_cfg = ROBOT_TYPE_CONFIG_MAP.get("libero_franka", None)
    if libero_cfg is not None:
        if action_horizon > 8:  # default LIBERO action_indices is range(8)
            libero_cfg.action_indices = list(range(action_horizon))
            logger.info(f"[pi0_data] Overriding action_indices to range({action_horizon})")

        # Skip data-level q99 normalization — Pi0 does MEAN_STD normalization in model
        skip_action_norm = getattr(data_cfg, 'skip_action_norm', True)
        if skip_action_norm:
            from AlphaBrain.dataloader.gr00t_lerobot.transform.state_action import StateActionToTensor
            from AlphaBrain.dataloader.gr00t_lerobot.transform.base import ComposedModalityTransform

            # Override transform to only do tensor conversion, no q99 normalization.
            # (The previous `original_transform` capture was unused and has been removed.)
            def raw_transform(self=libero_cfg):
                transforms = [StateActionToTensor(apply_to=self.action_keys)]
                return ComposedModalityTransform(transforms=transforms)

            libero_cfg.transform = raw_transform
            logger.info("[pi0_data] Disabled data-level action normalization (model handles MEAN_STD)")

    # Get the base LeRobot dataset
    base_dataset = get_vla_dataset(data_cfg, mode=mode, **kwargs)

    # Create Pi0 transform config from data_cfg
    pi0_config = Pi0DataConfig(
        action_horizon=action_horizon,
        action_dim=getattr(data_cfg, 'action_dim', 7),
        include_state=getattr(data_cfg, 'include_state', True),
        state_dim=getattr(data_cfg, 'state_dim', 7),
    )

    transform = Pi0DataTransform(config=pi0_config)

    # Wrap the dataset with Pi0 transforms
    return Pi0DatasetWrapper(base_dataset, transform)

VLM datasets

vlm_datasets

LazySupervisedDataset

LazySupervisedDataset(tokenizer: PreTrainedTokenizer, data_args)

Bases: Dataset

Dataset for supervised fine-tuning.

Source code in AlphaBrain/dataloader/vlm_datasets.py
def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_args):
    """Build the lazy SFT dataset: load annotations, sample, filter, shuffle.

    :param tokenizer: Tokenizer; its `max_len_single_sentence` bounds the
        word-count pre-filter.
    :param data_args: Config object providing `dataset_use` (comma-separated
        dataset names) and `model_type`; may provide video pixel limits.
    """
    super(LazySupervisedDataset, self).__init__()

    dataset = data_args.dataset_use.split(",")
    dataset_list = data_list(dataset)
    rank0_print(f"Loading datasets: {dataset_list}")
    self.video_max_total_pixels = getattr(data_args, "video_max_total_pixels", 1664 * 28 * 28)
    self.video_min_total_pixels = getattr(data_args, "video_min_total_pixels", 256 * 28 * 28)
    self.model_type = data_args.model_type
    # qwen2.5vl uses a different M-RoPE index helper than earlier models.
    if data_args.model_type == "qwen2.5vl":
        self.get_rope_index = get_rope_index_25
    else:
        self.get_rope_index = get_rope_index_2

    list_data_dict = []

    for data in dataset_list:
        file_format = data["annotation_path"].split(".")[-1]
        if file_format == "jsonl":
            annotations = read_jsonl(data["annotation_path"])
        else:
            # Use a context manager so the handle is closed promptly
            # (the previous `json.load(open(...))` leaked the file handle).
            with open(data["annotation_path"], "r") as f:
                annotations = json.load(f)
        sampling_rate = data.get("sampling_rate", 1.0)
        if sampling_rate < 1.0:
            annotations = random.sample(annotations, int(len(annotations) * sampling_rate))
            # NOTE(review): plain print here vs rank0_print elsewhere — confirm
            # whether per-rank output is intended.
            print(f"sampling {len(annotations)} examples from dataset {data}")
        else:
            rank0_print(f"dataset name: {data}")
        for ann in annotations:
            # Attach the root each sample resolves its media paths against.
            if data["data_path"] != "":
                ann["data_path"] = data["data_path"]
            elif "raw_data" in ann.keys():
                ann["data_path"] = ann["raw_data"]["data_root"]
        list_data_dict += annotations

    # Drop over-long conversations up front so encoding cannot overflow.
    list_data_dict = self.pre_filter_long_case(list_data_dict, max_words=tokenizer.max_len_single_sentence)
    random.shuffle(list_data_dict)  # Randomly shuffle the data for training

    self.tokenizer = tokenizer
    self.list_data_dict = list_data_dict
    self.data_args = data_args

    rank0_print(f"Total training samples: {len(self.list_data_dict)}")
    rank0_print("Formatting inputs...Skip in lazy mode")
pre_filter_long_case
pre_filter_long_case(list_data_dict, max_words=1024)

filter out conversations with total words exceeding max_words

Source code in AlphaBrain/dataloader/vlm_datasets.py
def pre_filter_long_case(self, list_data_dict, max_words=1024):
    """filter out conversations with total words exceeding max_words"""

    def _word_count(conversations):
        # Total whitespace-separated words across all turns.
        return sum(len(turn.get("value", "").strip().split()) for turn in conversations)

    kept = []
    for item in list_data_dict:
        if _word_count(item.get("conversations", [])) <= max_words:
            kept.append(item)
    return kept

DataCollatorForSupervisedDataset dataclass

DataCollatorForSupervisedDataset(tokenizer: PreTrainedTokenizer)

Bases: object

Collate examples for supervised fine-tuning.

FlattenedDataCollatorForSupervisedDataset dataclass

FlattenedDataCollatorForSupervisedDataset(tokenizer: PreTrainedTokenizer)

Bases: DataCollatorForSupervisedDataset

Collate examples into packed sequence with multi-modal support.

make_supervised_data_module

make_supervised_data_module(tokenizer: PreTrainedTokenizer, data_args) -> Dict

Make dataset and collator for supervised fine-tuning.

Source code in AlphaBrain/dataloader/vlm_datasets.py
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    # Training split is always built from data_args.dataset_use.
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_args=data_args)

    # Optional evaluation split: clone the args, pointing them at the eval list.
    eval_dataset = None
    if getattr(data_args, "eval_dataset", None):
        eval_data_args = copy.deepcopy(data_args)
        eval_data_args.dataset_use = data_args.eval_dataset
        eval_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_args=eval_data_args)

    # Packed-sequence collator when data_flatten is set, plain collator otherwise.
    collator_cls = (
        FlattenedDataCollatorForSupervisedDataset
        if data_args.data_flatten
        else DataCollatorForSupervisedDataset
    )

    return {
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
        "data_collator": collator_cls(tokenizer=tokenizer),
    }

Cosmos datasets

cosmos_datasets

LIBERO-Cosmos-Policy dataset loader for AlphaBrain training.

Dataset format

success_only/ — demo HDF5 files (data/demo_X/...); all_episodes/ — rollout HDF5 files (flat structure); t5_embeddings.pkl — T5 text embeddings dict; dataset_statistics.json — min/max normalization stats

CosmosLIBERODataset

CosmosLIBERODataset(data_dir: str, chunk_size: int = 16, final_image_size: int = 224, num_duplicates_per_image: int = 4, demonstration_sampling_prob: float = 0.5, success_rollout_sampling_prob: float = 0.5, use_image_aug: bool = True, use_stronger_image_aug: bool = True, normalize_actions: bool = True, normalize_proprio: bool = True, gamma: float = 0.99)

Bases: Dataset

Dataset for LIBERO-Cosmos-Policy format.

Loads demo data (success_only/) eagerly and rollout data (all_episodes/) lazily. Returns samples compatible with the cosmos-policy training pipeline.

Source code in AlphaBrain/dataloader/cosmos_datasets.py
def __init__(
    self,
    data_dir: str,
    chunk_size: int = 16,
    final_image_size: int = 224,
    num_duplicates_per_image: int = 4,
    demonstration_sampling_prob: float = 0.5,
    success_rollout_sampling_prob: float = 0.5,
    use_image_aug: bool = True,
    use_stronger_image_aug: bool = True,
    normalize_actions: bool = True,
    normalize_proprio: bool = True,
    gamma: float = 0.99,
):
    """Initialize the LIBERO-Cosmos-Policy dataset.

    Stores all config flags on ``self`` verbatim, loads T5 text embeddings
    and normalization statistics from ``<data_dir>/success_only/``, then
    builds the episode/step indices: demos are loaded eagerly, rollouts only
    via metadata (loaded lazily by the helpers called at the end).

    :param data_dir: Root directory containing ``success_only/`` and
        ``all_episodes/``.

    The remaining parameters are stored as-is and consumed by the loading,
    sampling, and augmentation helpers defined elsewhere in the class.
    """
    self.data_dir = data_dir
    self.chunk_size = chunk_size
    self.final_image_size = final_image_size
    self.num_duplicates_per_image = num_duplicates_per_image
    self.demonstration_sampling_prob = demonstration_sampling_prob
    self.success_rollout_sampling_prob = success_rollout_sampling_prob
    self.use_image_aug = use_image_aug
    self.use_stronger_image_aug = use_stronger_image_aug
    self.normalize_actions = normalize_actions
    self.normalize_proprio = normalize_proprio
    self.gamma = gamma

    # Paths
    self.demo_dir = os.path.join(data_dir, "success_only")
    self.rollout_dir = os.path.join(data_dir, "all_episodes")
    # t5_embeddings.pkl and dataset_statistics.json live under success_only/
    t5_path = os.path.join(data_dir, "success_only", "t5_embeddings.pkl")
    stats_path = os.path.join(data_dir, "success_only", "dataset_statistics.json")

    # Load T5 embeddings
    # NOTE(review): pickle.load is unsafe on untrusted files; acceptable only
    # for trusted, locally produced data.
    with open(t5_path, "rb") as f:
        self.t5_text_embeddings = pickle.load(f)

    # Load normalization statistics (JSON lists -> numpy arrays).
    with open(stats_path, "r") as f:
        json_stats = json.load(f)
    self.dataset_stats = {k: np.array(v) for k, v in json_stats.items()}

    # Storage
    self.data = {}           # episode_idx -> demo episode dict
    self.num_episodes = 0
    self.num_steps = 0

    self.rollout_episode_metadata = {}   # episode_idx -> metadata dict
    self.rollout_num_episodes = 0

    # The calls below are order-dependent: step-index mappings require the
    # corresponding episodes/metadata to be loaded first.
    # Load demos eagerly
    self._load_demos()

    # Build demo step index mapping
    self._build_demo_step_index_mapping()

    # Load rollout metadata lazily
    self._load_rollout_metadata()

    # Build rollout step index mapping
    self._build_rollout_step_index_mapping()

    # Calculate epoch structure
    self._calculate_epoch_structure()

cosmos_collate_fn

cosmos_collate_fn(batch)

Default collate; tensors are stacked, scalars become tensors.

Source code in AlphaBrain/dataloader/cosmos_datasets.py
def cosmos_collate_fn(batch):
    """Default collate; tensors are stacked, scalars become tensors."""
    from torch.utils.data.dataloader import default_collate as _collate

    # Deliberately thin: defers all batching rules to torch's default logic.
    return _collate(batch)

GR00T LeRobot subpackage

gr00t_lerobot

data_config

RobommePandaDataConfig

Data config for RoboMME benchmark with Panda robot. 8-dim action/state, image + wrist_image (stored as image bytes in parquet).

datasets

In this file, we define 3 types of datasets: 1. LeRobotSingleDataset: a single dataset for a given embodiment tag 2. LeRobotMixtureDataset: a mixture of datasets for a given list of embodiment tags 3. CachedLeRobotSingleDataset: a single dataset for a given embodiment tag, with caching for the video frames

See scripts/load_dataset.py for examples on how to use these datasets.

ModalityConfig

Bases: BaseModel

Configuration for a modality.

delta_indices instance-attribute
delta_indices: list[int]

Delta indices to sample relative to the current index. The returned data will correspond to the original data at a sampled base index + delta indices.

modality_keys instance-attribute
modality_keys: list[str]

The keys to load for the modality in the dataset.

LeRobotSingleDataset
LeRobotSingleDataset(dataset_path: Path | str, modality_configs: dict[str, ModalityConfig], embodiment_tag: str | EmbodimentTag, video_backend: str = 'decord', video_backend_kwargs: dict | None = None, transforms: ComposedModalityTransform | None = None, delete_pause_frame: bool = False, data_cfg=None, **kwargs)

Bases: Dataset

Base dataset class for LeRobot that supports sharding.

Initialize the dataset.

Parameters:

Name Type Description Default
dataset_path Path | str

The path to the dataset.

required
modality_configs dict[str, ModalityConfig]

The configuration for each modality. The keys are the modality names, and the values are the modality configurations. See ModalityConfig for more details.

required
video_backend str

Backend for video reading.

'decord'
video_backend_kwargs dict

Keyword arguments for the video backend when initializing the video reader.

None
transforms ComposedModalityTransform

The transforms to apply to the dataset.

None
embodiment_tag EmbodimentTag

Overload the embodiment tag for the dataset. e.g. define it as "new_embodiment"

required
Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def __init__(
    self,
    dataset_path: Path | str,
    modality_configs: dict[str, ModalityConfig],
    embodiment_tag: str | EmbodimentTag,
    video_backend: str = "decord",
    video_backend_kwargs: dict | None = None,
    transforms: ComposedModalityTransform | None = None,
    delete_pause_frame: bool = False,
    data_cfg = None,
    **kwargs,
):
    """
    Initialize the dataset.

    Args:
        dataset_path (Path | str): The path to the dataset.
        modality_configs (dict[str, ModalityConfig]): The configuration for each modality. The keys are the modality names, and the values are the modality configurations.
            See `ModalityConfig` for more details.
        video_backend (str): Backend for video reading.
        video_backend_kwargs (dict): Keyword arguments for the video backend when initializing the video reader.
        transforms (ComposedModalityTransform): The transforms to apply to the dataset.
        embodiment_tag (EmbodimentTag): Overload the embodiment tag for the dataset. e.g. define it as "new_embodiment"
        delete_pause_frame (bool): Whether to drop pause frames from trajectories.
        data_cfg: Optional dataset config; may carry `lerobot_version` and `filter_column`.
    """
    # First check that the dataset directory exists.
    self.data_cfg = data_cfg
    if not Path(dataset_path).exists():
        raise FileNotFoundError(f"Dataset path {dataset_path} does not exist")
    # Detect the LeRobot on-disk format version ("v2.0" or "v3.0").
    # Bug fix: guard against data_cfg being None (its default) — the previous
    # unconditional `self.data_cfg.get(...)` raised AttributeError whenever
    # no config was supplied.
    self._lerobot_version = self.data_cfg.get("lerobot_version", "v2.0") if self.data_cfg else "v2.0"

    self._action_mode = None
    self._action_mode_state_map = {}
    self._action_mode_apply_keys = None

    self.delete_pause_frame = delete_pause_frame

    self.modality_configs = modality_configs
    self.video_backend = video_backend
    self.video_backend_kwargs = video_backend_kwargs if video_backend_kwargs is not None else {}
    self.transforms = (
        transforms if transforms is not None else ComposedModalityTransform(transforms=[])
    )

    self._dataset_path = Path(dataset_path)
    self._dataset_name = self._dataset_path.name
    # Accept either the EmbodimentTag enum or its plain string value.
    if isinstance(embodiment_tag, EmbodimentTag):
        self.tag = embodiment_tag.value
    else:
        self.tag = embodiment_tag

    self._init_action_mode()
    self._metadata = self._get_metadata(EmbodimentTag(self.tag))

    # LeRobot-specific config
    self._lerobot_modality_meta = self._get_lerobot_modality_meta()
    self._lerobot_info_meta = self._get_lerobot_info_meta()
    self._data_path_pattern = self._get_data_path_pattern()
    self._video_path_pattern = self._get_video_path_pattern()
    self._chunk_size = self._get_chunk_size()
    self._tasks = self._get_tasks()
    # self._episodes = self._get_episode_info() # TODO why we need this func
    self.curr_traj_data = None
    self.curr_traj_id = None

    # Filter config: skip frames where filter_column=True (e.g., is_demo for robomme)
    self._filter_column = None
    if self.data_cfg and self.data_cfg.get("filter_column", None):
        self._filter_column = self.data_cfg["filter_column"]
    self._traj_row_masks = {}  # trajectory_id -> boolean mask of valid rows

    self._trajectory_ids, self._trajectory_lengths = self._get_trajectories()
    self._modality_keys = self._get_modality_keys()
    self._delta_indices = self._get_delta_indices()
    self._all_steps = self._get_all_steps()
    self.set_transforms_metadata(self.metadata)
    self.set_epoch(0)

    print(f"Initialized dataset {self.dataset_name} with {embodiment_tag}")

    # Check if the dataset is valid
    self._check_integrity()
dataset_path property
dataset_path: Path

The path to the dataset that contains the METADATA_FILENAME file.

metadata property
metadata: DatasetMetadata

The metadata for the dataset, loaded from metadata.json in the dataset directory

trajectory_ids property
trajectory_ids: ndarray

The trajectory IDs in the dataset, stored as a 1D numpy array of strings.

trajectory_lengths property
trajectory_lengths: ndarray

The trajectory lengths in the dataset, stored as a 1D numpy array of integers. The order of the lengths is the same as the order of the trajectory IDs.

all_steps property
all_steps: list[tuple[int, int]]

The trajectory IDs and base indices for all steps in the dataset. Example: self.trajectory_ids: [0, 1, 2] self.trajectory_lengths: [3, 2, 4] return: [ ("traj_0", 0), ("traj_0", 1), ("traj_0", 2), ("traj_1", 0), ("traj_1", 1), ("traj_2", 0), ("traj_2", 1), ("traj_2", 2), ("traj_2", 3) ]

modality_keys property
modality_keys: dict

The modality keys for the dataset. The keys are the modality names, and the values are the keys for each modality.

{

"video": ["video.image_side_0", "video.image_side_1"], "state": ["state.eef_position", "state.eef_rotation"], "action": ["action.eef_position", "action.eef_rotation"], "language": ["language.human.task"], "timestamp": ["timestamp"], "reward": ["reward"],

}

delta_indices property
delta_indices: dict[str, ndarray]

The delta indices for the dataset. The keys are the modality.key, and the values are the delta indices for each modality.key.

dataset_name property
dataset_name: str

The name of the dataset.

lerobot_modality_meta property
lerobot_modality_meta: LeRobotModalityMetadata

The metadata for the LeRobot dataset.

lerobot_info_meta property
lerobot_info_meta: dict

The metadata for the LeRobot dataset.

data_path_pattern property
data_path_pattern: str

The path pattern for the LeRobot dataset.

video_path_pattern property
video_path_pattern: str

The path pattern for the LeRobot dataset.

chunk_size property
chunk_size: int

The chunk size for the LeRobot dataset.

tasks property
tasks: DataFrame

The tasks for the dataset.

set_transforms_metadata
set_transforms_metadata(metadata: DatasetMetadata)

Set the metadata for the transforms. This is useful for transforms that need to know the metadata, such as the normalization values.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def set_transforms_metadata(self, metadata: DatasetMetadata):
    """Set the metadata for the transforms. This is useful for transforms that need to know the metadata, such as the normalization values."""
    # Delegates to the composed transform pipeline so e.g. normalization
    # transforms can pick up per-dataset statistics.
    self.transforms.set_metadata(metadata)
set_epoch
set_epoch(epoch: int)

Set the epoch for the dataset.

Parameters:

Name Type Description Default
epoch int

The epoch to set.

required
Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def set_epoch(self, epoch: int):
    """Set the epoch for the dataset.

    Args:
        epoch (int): The epoch to set.
    """
    # Stored for epoch-dependent behavior by consumers of the dataset.
    self.epoch = epoch
get_step_data
get_step_data(trajectory_id: int, base_index: int) -> dict

Get the RAW data for a single step in a trajectory. No transforms are applied.

Parameters:

Name Type Description Default
trajectory_id int

The name of the trajectory.

required
base_index int

The base step index in the trajectory.

required

Returns:

Name Type Description
dict dict

The RAW data for the step.

Example return

{ "video": { "video.image_side_0": [B, T, H, W, C], "video.image_side_1": [B, T, H, W, C], }, "state": { "state.eef_position": [B, T, state_dim], "state.eef_rotation": [B, T, state_dim], }, "action": { "action.eef_position": [B, T, action_dim], "action.eef_rotation": [B, T, action_dim], }, }

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def get_step_data(self, trajectory_id: int, base_index: int) -> dict:
    """Get the RAW data for a single step in a trajectory. No transforms are applied.

    Args:
        trajectory_id (int): The name of the trajectory.
        base_index (int): The base step index in the trajectory.

    Returns:
        dict: The RAW data for the step, keyed by "modality.key" strings such
        as "video.image_side_0", "state.eef_position", "action.eef_position".
    """
    # Populate the trajectory cache so per-key readers below reuse it.
    self.curr_traj_data = self.get_trajectory_data(trajectory_id)
    # TODO: The logic below is poorly implemented. Data reading should be
    # directly based on curr_traj_data.
    step_data = {}
    for modality, keys in self.modality_keys.items():
        for key in keys:
            step_data[key] = self.get_data_by_modality(trajectory_id, modality, key, base_index)
    # Apply the configured action-mode remapping before returning.
    return self._apply_action_mode(step_data)
get_trajectory_data
get_trajectory_data(trajectory_id: int) -> pd.DataFrame

Get the data for a trajectory.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def get_trajectory_data(self, trajectory_id: int) -> pd.DataFrame:
    """Get the data for a trajectory.

    Dispatches on the detected LeRobot on-disk format version and serves
    repeated requests for the same trajectory from `self.curr_traj_data`.

    Args:
        trajectory_id (int): The trajectory (episode) to load.

    Returns:
        pd.DataFrame: The trajectory's rows (row-mask filtered for v2.0).

    Raises:
        ValueError: If the LeRobot version is not supported.
    """
    if self._lerobot_version == "v2.0":
        # Serve from cache when the same trajectory is requested again.
        if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None:
            return self.curr_traj_data
        chunk_index = self.get_episode_chunk(trajectory_id)
        parquet_path = self.dataset_path / self.data_path_pattern.format(
            episode_chunk=chunk_index, episode_index=trajectory_id
        )
        assert parquet_path.exists(), f"Parquet file not found at {parquet_path}"
        df = _safe_read_parquet(parquet_path)
        # Apply filter (e.g., remove is_demo=True rows)
        if trajectory_id in self._traj_row_masks:
            mask = self._traj_row_masks[trajectory_id]
            df = df[mask].reset_index(drop=True)
        return df
    elif self._lerobot_version == "v3.0":
        return self.get_trajectory_data_lerobot_v3(trajectory_id)
    # Bug fix: previously an unknown version fell through and returned None
    # silently, which surfaced as a confusing downstream error; fail fast.
    raise ValueError(f"Unsupported LeRobot version: {self._lerobot_version}")
get_trajectory_data_lerobot_v3
get_trajectory_data_lerobot_v3(trajectory_id: int) -> pd.DataFrame

Get the data for a trajectory from lerobot v3.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def get_trajectory_data_lerobot_v3(self, trajectory_id: int) -> pd.DataFrame:
    """Get the data for a trajectory from lerobot v3."""
    # Serve from cache when the same trajectory is requested again.
    if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None:
        return self.curr_traj_data

    # TODO check detail later
    # v3 packs several episodes into one parquet file: locate that file via
    # the per-episode metadata, then slice out this episode's rows.
    episode_meta = self.trajectory_ids_to_metadata[trajectory_id]
    parquet_path = self.dataset_path / self.data_path_pattern.format(
        chunk_index=episode_meta["data/chunk_index"],
        file_index=self.get_episode_file_index(trajectory_id),
    )
    assert parquet_path.exists(), f"Parquet file not found at {parquet_path}"
    file_data = _safe_read_parquet(parquet_path)

    # filter by trajectory_id
    return file_data.loc[file_data["episode_index"] == trajectory_id].copy()
get_trajectory_index
get_trajectory_index(trajectory_id: int) -> int

Get the index of the trajectory in the dataset by the trajectory ID. This is useful when you need to get the trajectory length or sampling weight corresponding to the trajectory ID.

Parameters:

Name Type Description Default
trajectory_id str

The ID of the trajectory.

required

Returns:

Name Type Description
int int

The index of the trajectory in the dataset.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def get_trajectory_index(self, trajectory_id: int) -> int:
    """Get the index of the trajectory in the dataset by the trajectory ID.
    This is useful when you need to get the trajectory length or sampling weight corresponding to the trajectory ID.

    Args:
        trajectory_id (int): The ID of the trajectory.

    Returns:
        int: The index of the trajectory in the dataset.

    Raises:
        ValueError: If the ID matches zero or more than one trajectory.
    """
    trajectory_indices = np.where(self.trajectory_ids == trajectory_id)[0]
    if len(trajectory_indices) != 1:
        raise ValueError(
            f"Error finding trajectory index for {trajectory_id}, found {trajectory_indices=}"
        )
    # Cast the numpy integer to a plain int to honor the declared return type.
    return int(trajectory_indices[0])
get_episode_chunk
get_episode_chunk(ep_index: int) -> int

Get the chunk index for an episode index.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def get_episode_chunk(self, ep_index: int) -> int:
    """Map an episode index onto the chunk that contains it."""
    chunk, _ = divmod(ep_index, self.chunk_size)
    return chunk
get_episode_file_index
get_episode_file_index(ep_index: int) -> int

Get the file index for an episode index.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def get_episode_file_index(self, ep_index: int) -> int:
    """Return which data file stores the given episode."""
    return self.trajectory_ids_to_metadata[ep_index]["data/file_index"]
get_episode_file_from_index
get_episode_file_from_index(ep_index: int) -> int

Get the file from index for an episode index.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def get_episode_file_from_index(self, ep_index: int) -> int:
    """Return the 'file_from_index' recorded in the episode's metadata."""
    return self.trajectory_ids_to_metadata[ep_index]["data/file_from_index"]
retrieve_data_and_pad
retrieve_data_and_pad(array: ndarray, step_indices: ndarray, max_length: int, padding_strategy: str = 'first_last') -> np.ndarray

Retrieve the data from the dataset and pad it if necessary. Args: array (np.ndarray): The array to retrieve the data from. step_indices (np.ndarray): The step indices to retrieve the data for. max_length (int): The maximum length of the data. padding_strategy (str): The padding strategy, either "first_last" or "zero".

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def retrieve_data_and_pad(
    self,
    array: np.ndarray,
    step_indices: np.ndarray,
    max_length: int,
    padding_strategy: str = "first_last",
) -> np.ndarray:
    """Gather rows of `array` at `step_indices`, padding out-of-range steps.

    Args:
        array (np.ndarray): Source data, indexed along its first axis.
        step_indices (np.ndarray): Indices to gather; entries < 0 or
            >= max_length are treated as padding positions.
        max_length (int): The number of valid steps.
        padding_strategy (str): "first_last" repeats the boundary rows of
            `array` into padded slots; "zero" leaves them as zeros.

    Returns:
        np.ndarray: Array of shape (len(step_indices), ...), float64.
    """
    # Classify each requested step: before the start, past the end, or valid.
    before_start = step_indices < 0
    past_end = step_indices >= max_length
    padding_positions = before_start | past_end

    # Gather only the in-range rows; shape (T', ...) with T' <= T.
    raw_data = array[step_indices[~padding_positions]]
    assert isinstance(raw_data, np.ndarray), f"{type(raw_data)=}"

    # The output shape is (T, ...) regardless of how many steps were padded.
    if raw_data.ndim == 1:
        expected_shape = (len(step_indices),)
    else:
        expected_shape = (len(step_indices), *array.shape[1:])

    output = np.zeros(expected_shape)
    output[~padding_positions] = raw_data

    if padding_positions.any():
        if padding_strategy == "first_last":
            # Repeat the trajectory's boundary rows into the padded slots.
            output[before_start] = array[0]
            output[past_end] = array[-1]
        elif padding_strategy == "zero":
            output[padding_positions] = 0
        else:
            raise ValueError(f"Invalid padding strategy: {padding_strategy}")
    return output
get_video
get_video(trajectory_id: int, key: str, base_index: int) -> np.ndarray

Get the video frames for a trajectory by a base index.

Parameters:

Name Type Description Default
dataset BaseSingleDataset

The dataset to retrieve the data from.

required
trajectory_id str

The ID of the trajectory.

required
key str

The key of the video.

required
base_index int

The base index of the trajectory.

required

Returns:

Type Description
ndarray

np.ndarray: The video frames for the trajectory and frame indices. Shape: (T, H, W, C)

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def get_video(
    self,
    trajectory_id: int,
    key: str,
    base_index: int,
) -> np.ndarray:
    """Get the video frames for a trajectory by a base index.

    Args:
        trajectory_id (int): The ID of the trajectory.
        key (str): The key of the video; must start with "video.".
        base_index (int): The base index of the trajectory.

    Returns:
        np.ndarray: The video frames for the trajectory and frame indices. Shape: (T, H, W, C)
    """
    # Get the step indices
    step_indices = self.delta_indices[key] + base_index
    # print(f"{step_indices=}")
    # Get the trajectory index
    trajectory_index = self.get_trajectory_index(trajectory_id)
    # Ensure the indices are within the valid range
    # This is equivalent to padding the video with extra frames at the beginning and end
    step_indices = np.maximum(step_indices, 0)
    step_indices = np.minimum(step_indices, self.trajectory_lengths[trajectory_index] - 1)
    assert key.startswith("video."), f"Video key must start with 'video.', got {key}"
    # Get the sub-key
    key = key.replace("video.", "")

    # Check if this is image-type data (PNG bytes in parquet)
    if key in self._image_type_video_keys:
        assert self.curr_traj_data is not None
        return self._read_images_from_parquet(trajectory_id, key, step_indices)

    video_path = self.get_video_path(trajectory_id, key)
    # Get the action/state timestamps for each frame in the video
    assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"
    assert "timestamp" in self.curr_traj_data.columns, f"No timestamp found in {trajectory_id=}"
    timestamp: np.ndarray = self.curr_traj_data["timestamp"].to_numpy()
    # Get the corresponding video timestamps from the step indices
    video_timestamp = timestamp[step_indices]
    if self._lerobot_version == "v3.0":
        # v3 files can hold several episodes; shift the timestamps by this
        # episode's start offset ("videos/from_timestamps") within the file.
        episode_meta = self.trajectory_ids_to_metadata.get(trajectory_id, {})
        from_timestamps = episode_meta.get("videos/from_timestamps", {})
        original_video_key = self.lerobot_modality_meta.video[key].original_key
        if original_video_key is None:
            original_video_key = key
        from_timestamp = float(from_timestamps.get(original_video_key, 0.0))
        video_timestamp = video_timestamp + from_timestamp

    # V2: batch-read next frame in same video file open
    # (decode the upcoming frame now to avoid reopening the video next call).
    next_frame_offset = None
    if (getattr(self, '_v2_next_frame_index', None) is not None
            and not key.startswith("wrist")):
        next_step = np.array([self._v2_next_frame_index])
        next_step = np.maximum(next_step, 0)
        next_step = np.minimum(next_step, self.trajectory_lengths[trajectory_index] - 1)
        next_ts = timestamp[next_step]
        # NOTE(review): `'from_timestamp' in dir()` checks whether the local
        # variable was bound in the v3.0 branch above — fragile; an explicit
        # flag would be clearer. Confirm intent before changing.
        if self._lerobot_version == "v3.0" and 'from_timestamp' in dir():
            next_ts = next_ts + from_timestamp
        next_frame_offset = len(video_timestamp)
        video_timestamp = np.concatenate([video_timestamp, next_ts])

    frames = get_frames_by_timestamps(
        video_path.as_posix(),
        video_timestamp,
        video_backend=self.video_backend,
        video_backend_kwargs=self.video_backend_kwargs,
    )

    if next_frame_offset is not None:
        # Split off the appended frame: cache it for the next call and return
        # only the frames requested for this window.
        self._v2_cached_next_frame = frames[next_frame_offset:]
        return frames[:next_frame_offset]

    return frames
get_state_or_action
get_state_or_action(trajectory_id: int, modality: str, key: str, base_index: int) -> np.ndarray

Get the state or action data for a trajectory by a base index. If the step indices are out of range, pad with the data: if the data is stored in absolute format, pad with the first or last step data; otherwise, pad with zero.

Parameters:

Name Type Description Default
dataset BaseSingleDataset

The dataset to retrieve the data from.

required
trajectory_id int

The ID of the trajectory.

required
modality str

The modality of the data.

required
key str

The key of the data.

required
base_index int

The base index of the trajectory.

required

Returns:

Type Description
ndarray

np.ndarray: The data for the trajectory and step indices.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def get_state_or_action(
    self,
    trajectory_id: int,
    modality: str,
    key: str,
    base_index: int,
) -> np.ndarray:
    """Get the state or action data for a trajectory by a base index.

    Out-of-range step indices are padded: absolute quantities repeat the
    first/last frame, relative quantities are padded with zeros.

    Args:
        trajectory_id (int): The ID of the trajectory.
        modality (str): Either "state" or "action".
        key (str): The modality-scoped key, e.g. "state.joint_angles".
        base_index (int): The base index of the trajectory.

    Returns:
        np.ndarray: The data for the trajectory and step indices.
    """
    # Absolute step indices of the requested window.
    step_indices = self.delta_indices[key] + base_index
    trajectory_index = self.get_trajectory_index(trajectory_id)
    max_length = self.trajectory_lengths[trajectory_index]

    assert key.startswith(modality + "."), f"{key} must start with {modality + '.'}, got {key}"
    # Strip the modality prefix: "state.joint_angles" -> "joint_angles".
    key = key.replace(modality + ".", "")

    # Resolve the LeRobot column for this key (fall back to the sub-key itself).
    le_state_or_action_cfg = getattr(self.lerobot_modality_meta, modality)
    le_key = le_state_or_action_cfg[key].original_key
    if le_key is None:
        le_key = key

    assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"
    assert le_key in self.curr_traj_data.columns, f"No {le_key} found in {trajectory_id=}"
    # Stack the per-step vectors into a (T, D) array.
    data_array: np.ndarray = np.stack(self.curr_traj_data[le_key])  # type: ignore
    assert data_array.ndim == 2, f"Expected 2D array, got key {le_key} is{data_array.shape} array"
    # Select only the dimensions belonging to this key.
    le_indices = np.arange(
        le_state_or_action_cfg[key].start,
        le_state_or_action_cfg[key].end,
    )
    data_array = data_array[:, le_indices]

    # Absolute signals repeat boundary frames when padding; relative ones
    # are zero-padded.
    state_or_action_cfg = getattr(self.metadata.modalities, modality)[key]
    return self.retrieve_data_and_pad(
        array=data_array,
        step_indices=step_indices,
        max_length=max_length,
        padding_strategy="first_last" if state_or_action_cfg.absolute else "zero",
    )
get_language
get_language(trajectory_id: int, key: str, base_index: int) -> list[str]

Get the language annotation data for a trajectory by step indices.

Parameters:

Name Type Description Default
dataset BaseSingleDataset

The dataset to retrieve the data from.

required
trajectory_id int

The ID of the trajectory.

required
key str

The key of the annotation.

required
base_index int

The base index of the trajectory.

required

Returns:

Type Description
list[str]

list[str]: The annotation data for the trajectory and step indices. If no matching data is found, return empty strings.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def get_language(
    self,
    trajectory_id: int,
    key: str,
    base_index: int,
) -> list[str]:
    """Get the language annotation strings for one step window of a trajectory.

    Args:
        trajectory_id (int): The ID of the trajectory.
        key (str): The key of the annotation; must start with "annotation.".
        base_index (int): The base index of the trajectory.

    Returns:
        list[str]: One task string per step in the window, looked up via the
            per-step task indices in the tasks table.
    """
    assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"
    # Absolute step indices of the requested window.
    step_indices = self.delta_indices[key] + base_index
    trajectory_index = self.get_trajectory_index(trajectory_id)
    max_length = self.trajectory_lengths[trajectory_index]
    # Clamp out-of-range steps to the trajectory boundaries.
    step_indices = np.clip(step_indices, 0, max_length - 1)

    assert key.startswith(
        "annotation."
    ), f"Language key must start with 'annotation.', got {key}"
    subkey = key.replace("annotation.", "")
    annotation_meta = self.lerobot_modality_meta.annotation
    assert annotation_meta is not None, f"Annotation metadata is None for {subkey}"
    assert (
        subkey in annotation_meta
    ), f"Annotation key {subkey} not found in metadata, available annotation keys: {annotation_meta.keys()}"
    subkey_meta = annotation_meta[subkey]
    original_key = subkey_meta.original_key
    if original_key is None:
        original_key = key

    # Map each step to its task index in the tasks table.
    column = self.curr_traj_data[original_key]
    task_indices: list[int] = []
    for step in step_indices:  # TODO check v2.0
        value = column.iloc[step]
        task_indices.append(value if isinstance(value, (int, float)) else value.item())

    return self.tasks.loc[task_indices]["task"].tolist()
get_data_by_modality
get_data_by_modality(trajectory_id: int, modality: str, key: str, base_index: int)

Get the data corresponding to the modality for a trajectory by a base index. This method will call the corresponding helper method based on the modality. See the helper methods for more details. NOTE: For the language modality, the data is padded with empty strings if no matching data is found.

Parameters:

Name Type Description Default
dataset BaseSingleDataset

The dataset to retrieve the data from.

required
trajectory_id int

The ID of the trajectory.

required
modality str

The modality of the data.

required
key str

The key of the data.

required
base_index int

The base index of the trajectory.

required
Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def get_data_by_modality(
    self,
    trajectory_id: int,
    modality: str,
    key: str,
    base_index: int,
):
    """Get the data corresponding to the modality for a trajectory by a base index.
    This method will call the corresponding helper method based on the modality.
    See the helper methods for more details.
    NOTE: For the language modality, the data is padded with empty strings if no matching data is found.

    Args:
        dataset (BaseSingleDataset): The dataset to retrieve the data from.
        trajectory_id (int): The ID of the trajectory.
        modality (str): The modality of the data.
        key (str): The key of the data.
        base_index (int): The base index of the trajectory.
    """
    if modality == "video":
        return self.get_video(trajectory_id, key, base_index)
    elif modality == "state" or modality == "action":
        return self.get_state_or_action(trajectory_id, modality, key, base_index)
    elif modality == "language":
        return self.get_language(trajectory_id, key, base_index)
    else:
        raise ValueError(f"Invalid modality: {modality}")
CachedLeRobotSingleDataset
CachedLeRobotSingleDataset(img_resize: tuple[int, int] | None = None, *args, **kwargs)

Bases: LeRobotSingleDataset

This class caches the video frames for each trajectory and key. It is recommended to use this class if the video frames need to be accessed multiple times.

Parameters:

Name Type Description Default
img_resize tuple[int, int] | None

The size to resize the video frames to reduce memory usage.

required
Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def __init__(self, img_resize: tuple[int, int] | None = None, *args, **kwargs):
    """
    This class caches the video frames for each trajectory and key.
    It is recommended to use this class if the video frames need to be accessed multiple times.

    Args:
        img_resize (tuple[int, int], optional): The size to resize the video frames to reduce memory usage.
    """
    # Convert img_resize to tuple if it is not already
    if img_resize is not None and not isinstance(img_resize, tuple):
        img_resize = tuple(img_resize)
        assert len(img_resize) == 2, f"Expected tuple of length 2, got {img_resize}"
    self.img_resize = img_resize

    # Initialize img_resize attribute first to ensure it exists
    super().__init__(*args, **kwargs)
    # Frames cached per (prefix-stripped) video key, concatenated over all trajectories.
    cached_frames: dict[str, np.ndarray] = {}

    for key in self.modality_keys["video"]:
        all_frames = []
        # NOTE(review): original_key is assigned but not used below; orig_key
        # is resolved from the modality metadata instead.
        original_key = key
        key = key.replace("video.", "")
        # Image-type keys store encoded image bytes in the parquet rows
        # instead of referencing a separate video file.
        is_image_type = key in self._image_type_video_keys
        for trajectory_id, trajectory_length in tqdm(
            zip(self.trajectory_ids, self.trajectory_lengths),
            total=len(self.trajectory_ids),
            desc=f"Caching {key} frames",
        ):
            if is_image_type:
                import io
                traj_data = self.get_trajectory_data(trajectory_id)
                orig_key = self.lerobot_modality_meta.video[key].original_key or key
                frame_list = []
                for idx in range(len(traj_data)):
                    img_data = traj_data.iloc[idx][orig_key]
                    if isinstance(img_data, dict):
                        # Parquet may wrap the encoded image as {"bytes": ...}.
                        img_bytes = img_data.get("bytes", img_data.get(b"bytes"))
                    else:
                        img_bytes = img_data
                    img = Image.open(io.BytesIO(img_bytes))
                    if img_resize is not None:
                        # PIL resize takes (width, height); img_resize appears
                        # to be (height, width) — TODO confirm.
                        img = img.resize((img_resize[1], img_resize[0]))
                    frame_list.append(np.array(img))
                frames = np.stack(frame_list, axis=0)
            else:
                video_path = self.get_video_path(trajectory_id, key)
                frames = get_all_frames(
                    video_path.as_posix(),
                    video_backend=self.video_backend,
                    video_backend_kwargs=self.video_backend_kwargs,
                    resize_size=img_resize,
                )
            assert frames.ndim == 4, f"Expected 4D array, got {frames.shape} array"
            assert frames.shape[3] == 3, f"Expected 3 channels, got {frames.shape[3]} channels"

            # Apply image cropping if enabled and the video key is base_view
            # Note: crop_obs_camera functionality has been removed

            # assert (
            #     frames.shape[0] == trajectory_length
            # ), f"Expected {trajectory_length} frames, got {frames.shape[0]} frames"
            all_frames.append(frames)
        cached_frames[key] = np.concatenate(all_frames, axis=0)
        print(f"{key}: {cached_frames[key].shape}")
    self.cached_frames = cached_frames
    # start_indices[i] is the offset of trajectory i's first frame in the
    # concatenated cache (exclusive cumulative sum of trajectory lengths).
    self.start_indices = np.cumsum(self.trajectory_lengths) - self.trajectory_lengths
get_step_data
get_step_data(trajectory_id: int, base_index: int) -> dict

Get the RAW data for a single step. No transforms are applied.

Parameters:

Name Type Description Default
trajectory_id int

The ID of the trajectory.

required
base_index int

The base index of the step.

required

Returns:

Name Type Description
dict dict

The data for the step.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def get_step_data(self, trajectory_id: int, base_index: int) -> dict:
    """Get the RAW data for a single step. No transforms are applied.

    Args:
        trajectory_id (int): The ID of the trajectory.
        base_index (int): The base index of the step.

    Returns:
        dict: The data for the step, keyed by modality-scoped key.
    """
    # Load (and cache) the trajectory before fetching individual modalities.
    self.curr_traj_data = self.get_trajectory_data(trajectory_id)
    # One entry per key of every modality.
    return {
        key: self.get_data_by_modality(trajectory_id, modality, key, base_index)
        for modality, keys in self.modality_keys.items()
        for key in keys
    }
set_transforms_metadata
set_transforms_metadata(metadata: DatasetMetadata)

Set the metadata for the transforms. This is useful for transforms that need to know the metadata, such as the normalization values.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def set_transforms_metadata(self, metadata: DatasetMetadata):
    """Set the metadata for the transforms, reporting cached resolutions.

    When frames were cached at a resized resolution, the video metadata must
    advertise that resolution so transforms that read it stay consistent.
    """
    if self.img_resize is not None:
        resized_keys = set(self.modality_keys["video"])
        for key in metadata.modalities.video:
            if key in resized_keys:
                metadata.modalities.video[key].resolution = self.img_resize
    super().set_transforms_metadata(metadata)
LeRobotMixtureDataset
LeRobotMixtureDataset(data_mixture: Sequence[tuple[LeRobotSingleDataset, float]], mode: str, balance_dataset_weights: bool = True, balance_trajectory_weights: bool = True, seed: int = 42, metadata_config: dict = {'percentile_mixing_method': 'min_max'}, **kwargs)

Bases: Dataset

A mixture of multiple datasets. This class samples a single dataset based on the dataset weights and then calls the __getitem__ method of the sampled dataset. It is recommended to modify the single dataset class instead of this class.

Initialize the mixture dataset.

Parameters:

Name Type Description Default
data_mixture list[tuple[LeRobotSingleDataset, float]]

Datasets and their corresponding weights.

required
mode str

If "train", getitem will return different samples every epoch; if "val" or "test", getitem will return the same sample every epoch.

required
balance_dataset_weights bool

If True, the weight of dataset will be multiplied by the total trajectory length of each dataset.

True
balance_trajectory_weights bool

If True, sample trajectories within a dataset weighted by their length; otherwise, use equal weighting.

True
seed int

Random seed for sampling.

42
Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def __init__(
    self,
    data_mixture: Sequence[tuple[LeRobotSingleDataset, float]],
    mode: str,
    balance_dataset_weights: bool = True,
    balance_trajectory_weights: bool = True,
    seed: int = 42,
    metadata_config: dict | None = None,
    **kwargs,
):
    """
    Initialize the mixture dataset.

    Args:
        data_mixture (list[tuple[LeRobotSingleDataset, float]]): Datasets and their corresponding weights.
        mode (str): If "train", __getitem__ will return different samples every epoch; if "val" or "test", __getitem__ will return the same sample every epoch.
        balance_dataset_weights (bool): If True, the weight of dataset will be multiplied by the total trajectory length of each dataset.
        balance_trajectory_weights (bool): If True, sample trajectories within a dataset weighted by their length; otherwise, use equal weighting.
        seed (int): Random seed for sampling.
        metadata_config (dict | None): Configuration for merging per-dataset
            metadata. Defaults to {"percentile_mixing_method": "min_max"}.
    """
    # Resolve the default here instead of using a mutable default argument,
    # which would be a single dict shared across all instances.
    if metadata_config is None:
        metadata_config = {"percentile_mixing_method": "min_max"}

    # Drop empty datasets up front so the sampling weights stay well-defined.
    datasets: list[LeRobotSingleDataset] = []
    dataset_sampling_weights: list[float] = []
    for dataset, weight in data_mixture:
        # Check if dataset is valid and has data
        if len(dataset) == 0:
            print(f"Warning: Skipping empty dataset {dataset.dataset_name}")
            continue
        datasets.append(dataset)
        dataset_sampling_weights.append(weight)

    if len(datasets) == 0:
        raise ValueError("No valid datasets found in the mixture. All datasets are empty.")

    self.datasets = datasets
    self.balance_dataset_weights = balance_dataset_weights
    self.balance_trajectory_weights = balance_trajectory_weights
    self.seed = seed
    self.mode = mode
    # Optional extra data configuration forwarded by the caller.
    self.data_cfg = kwargs.get("data_cfg")

    # Set properties for sampling

    # 1. Dataset lengths
    self._dataset_lengths = np.array([len(dataset) for dataset in self.datasets])
    print(f"Dataset lengths: {self._dataset_lengths}")

    # 2. Dataset sampling weights
    self._dataset_sampling_weights = np.array(dataset_sampling_weights)

    if self.balance_dataset_weights:
        # Scale each dataset's weight by its size so larger datasets are
        # sampled proportionally more often.
        self._dataset_sampling_weights *= self._dataset_lengths

    # Check for zero or negative weights before normalization
    if np.any(self._dataset_sampling_weights <= 0):
        print(f"Warning: Found zero or negative sampling weights: {self._dataset_sampling_weights}")
        # Set minimum weight to prevent division issues
        self._dataset_sampling_weights = np.maximum(self._dataset_sampling_weights, 1e-8)

    # Normalize weights
    weights_sum = self._dataset_sampling_weights.sum()
    if weights_sum == 0 or np.isnan(weights_sum):
        print(f"Error: Invalid weights sum: {weights_sum}")
        # Fallback to equal weights
        self._dataset_sampling_weights = np.ones(len(self.datasets)) / len(self.datasets)
        print("Fallback to equal weights")
    else:
        self._dataset_sampling_weights /= weights_sum

    # 3. Trajectory sampling weights (one normalized array per dataset)
    self._trajectory_sampling_weights: list[np.ndarray] = []
    for i, dataset in enumerate(self.datasets):
        trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths))
        if self.balance_trajectory_weights:
            # Longer trajectories get proportionally more samples.
            trajectory_sampling_weights *= dataset.trajectory_lengths

        # Check for zero or negative weights before normalization
        if np.any(trajectory_sampling_weights <= 0):
            print(f"Warning: Dataset {i} has zero or negative trajectory weights")
            trajectory_sampling_weights = np.maximum(trajectory_sampling_weights, 1e-8)

        # Normalize weights
        weights_sum = trajectory_sampling_weights.sum()
        if weights_sum == 0 or np.isnan(weights_sum):
            print(f"Error: Dataset {i} has invalid trajectory weights sum: {weights_sum}")
            # Fallback to equal weights
            trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths)) / len(dataset.trajectory_lengths)
        else:
            trajectory_sampling_weights /= weights_sum

        self._trajectory_sampling_weights.append(trajectory_sampling_weights)

    # 4. Primary dataset indices (datasets whose configured weight is exactly 1.0)
    self._primary_dataset_indices = np.array(dataset_sampling_weights) == 1.0
    if not np.any(self._primary_dataset_indices):
        print(f"Warning: No dataset with weight 1.0 found. Original weights: {dataset_sampling_weights}")
        # Fallback: use the dataset(s) with maximum weight as primary
        max_weight = max(dataset_sampling_weights)
        self._primary_dataset_indices = np.array(dataset_sampling_weights) == max_weight
        print(f"Using datasets with maximum weight {max_weight} as primary: {self._primary_dataset_indices}")

    if not np.any(self._primary_dataset_indices):
        # This should never happen, but just in case
        print("Error: Still no primary dataset found. Using first dataset as primary.")
        self._primary_dataset_indices = np.zeros(len(self.datasets), dtype=bool)
        self._primary_dataset_indices[0] = True

    # Set the epoch and sample the first epoch
    self.set_epoch(0)

    # Sequential step sampling walks each dataset's steps in a (shuffled)
    # fixed order instead of sampling with replacement; enabled unless the
    # data config explicitly disables it.
    self._sequential_step_sampling = True
    if self.data_cfg is not None:
        seq_cfg = self.data_cfg.get("sequential_step_sampling", True)
        self._sequential_step_sampling = seq_cfg not in ["False", False]

    self._step_order: list[np.ndarray] = []
    self._step_pos: list[int] = []
    if self._sequential_step_sampling:
        for dataset in self.datasets:
            self._step_order.append(np.arange(len(dataset.all_steps)))
            if self.mode == "train":
                # NOTE(review): every dataset's initial order is shuffled with
                # the same seed; later reshuffles diverge — confirm intent.
                rng = np.random.default_rng(self.seed)
                rng.shuffle(self._step_order[-1])
            self._step_pos.append(0)

    self.update_metadata(metadata_config)
dataset_lengths property
dataset_lengths: ndarray

The lengths of each dataset.

dataset_sampling_weights property
dataset_sampling_weights: ndarray

The sampling weights for each dataset.

trajectory_sampling_weights property
trajectory_sampling_weights: list[ndarray]

The sampling weights for each trajectory in each dataset.

primary_dataset_indices property
primary_dataset_indices: ndarray

The indices of the primary datasets.

set_epoch
set_epoch(epoch: int)

Set the epoch for the dataset.

Parameters:

Name Type Description Default
epoch int

The epoch to set.

required
Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def set_epoch(self, epoch: int):
    """Record the current epoch.

    Args:
        epoch (int): The epoch to set; used to vary sampling across epochs.
    """
    self.epoch = epoch
sample_step
sample_step(index: int) -> tuple[LeRobotSingleDataset, int, int]

Sample a single step from the dataset.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def sample_step(self, index: int) -> tuple[LeRobotSingleDataset, int, int]:
    """Sample a single step from the dataset.

    Returns:
        tuple: (dataset, trajectory_id, base_index) identifying one step.
    """
    # return self.sampled_steps[index]

    # Set seed: deterministic per index in val/test; varies with epoch in train.
    seed = index if self.mode != "train" else safe_hash((self.epoch, index, self.seed))
    rng = np.random.default_rng(seed)

    # Sample dataset
    dataset_index = rng.choice(len(self.datasets), p=self.dataset_sampling_weights)
    dataset = self.datasets[dataset_index]

    # Sample trajectory
    # trajectory_index = rng.choice(
    #     len(dataset.trajectory_ids), p=self.trajectory_sampling_weights[dataset_index]
    # )
    # trajectory_id = dataset.trajectory_ids[trajectory_index]

    # # Sample step
    # base_index = rng.choice(dataset.trajectory_lengths[trajectory_index])
    # return dataset, trajectory_id, base_index
    if len(dataset.all_steps) == 0:
        raise ValueError(f"Dataset {dataset.dataset_name} has no steps.")

    if not self._sequential_step_sampling:
        # Independent uniform sampling over all steps of the chosen dataset.
        single_step_index = rng.choice(len(dataset.all_steps))
    else:
        # Sequential mode: walk a (possibly shuffled) fixed order; reshuffle
        # once the whole order has been consumed.
        step_pos = self._step_pos[dataset_index]
        if step_pos >= len(dataset.all_steps):
            order = np.arange(len(dataset.all_steps))
            if self.mode == "train":
                seed = safe_hash((self.epoch, dataset_index, self.seed, step_pos))
                rng = np.random.default_rng(seed)
                rng.shuffle(order)
            self._step_order[dataset_index] = order
            step_pos = 0

        single_step_index = self._step_order[dataset_index][step_pos]
        # NOTE(review): this mutates sampler state per call — presumably not
        # safe across dataloader workers sharing this object; confirm usage.
        self._step_pos[dataset_index] = step_pos + 1
    trajectory_id, base_index = dataset.all_steps[single_step_index]
    return dataset, trajectory_id, base_index
compute_overall_statistics staticmethod
compute_overall_statistics(per_task_stats: list[dict[str, dict[str, list[float] | ndarray]]], dataset_sampling_weights: list[float] | ndarray, percentile_mixing_method: str = 'weighted_average') -> dict[str, dict[str, list[float]]]

Computes overall statistics from per-task statistics using dataset sample weights.

Parameters:

Name Type Description Default
per_task_stats list[dict[str, dict[str, list[float] | ndarray]]]

List of per-task statistics.

required
Example format of one element in the per-task statistics list

{ "state.gripper": { "min": [...], "max": [...], "mean": [...], "std": [...], "q01": [...], "q99": [...], }, ... }

required
dataset_sampling_weights list[float] | ndarray

List of sample weights for each task.

required
percentile_mixing_method str

The method to mix the percentiles, either "weighted_average" or "weighted_std".

'weighted_average'

Returns:

Type Description
dict[str, dict[str, list[float]]]

A dict of overall statistics per modality.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
@staticmethod
def compute_overall_statistics(
    per_task_stats: list[dict[str, dict[str, list[float] | np.ndarray]]],
    dataset_sampling_weights: list[float] | np.ndarray,
    percentile_mixing_method: str = "weighted_average",
) -> dict[str, dict[str, list[float]]]:
    """
    Computes overall statistics from per-task statistics using dataset sample weights.

    Args:
        per_task_stats: List of per-task statistics.
        Example format of one element in the per-task statistics list:
            {
                "state.gripper": {
                    "min": [...],
                    "max": [...],
                    "mean": [...],
                    "std": [...],
                    "q01": [...],
                    "q99": [...],
                },
                ...
            }
        dataset_sampling_weights: List of sample weights for each task.
        percentile_mixing_method: The method to mix the percentiles, either "weighted_average" or "weighted_std".

    Returns:
        A dict of overall statistics per modality.
    """
    # Normalize the sample weights to sum to 1
    dataset_sampling_weights = np.array(dataset_sampling_weights)
    normalized_weights = dataset_sampling_weights / dataset_sampling_weights.sum()

    # Initialize overall statistics dict
    overall_stats: dict[str, dict[str, list[float]]] = {}

    # Get the list of modality keys
    modality_keys = per_task_stats[0].keys()

    for modality in modality_keys:
        # Number of dimensions (assuming consistent across tasks)
        num_dims = len(per_task_stats[0][modality]["mean"])

        # Initialize accumulators for means and variances
        weighted_means = np.zeros(num_dims)
        weighted_squares = np.zeros(num_dims)

        # Collect min, max, q01, q99 from all tasks
        min_list = []
        max_list = []
        q01_list = []
        q99_list = []

        for task_idx, task_stats in enumerate(per_task_stats):
            w_i = normalized_weights[task_idx]
            stats = task_stats[modality]
            means = np.array(stats["mean"])
            stds = np.array(stats["std"])

            # Update weighted sums for mean and variance
            weighted_means += w_i * means
            weighted_squares += w_i * (stds**2 + means**2)

            # Collect min, max, q01, q99
            min_list.append(stats["min"])
            max_list.append(stats["max"])
            q01_list.append(stats["q01"])
            q99_list.append(stats["q99"])

        # Compute overall mean
        overall_mean = weighted_means.tolist()

        # Compute overall variance and std deviation
        overall_variance = weighted_squares - weighted_means**2
        overall_std = np.sqrt(overall_variance).tolist()

        # Compute overall min and max per dimension
        overall_min = np.min(np.array(min_list), axis=0).tolist()
        overall_max = np.max(np.array(max_list), axis=0).tolist()

        # Compute overall q01 and q99 per dimension
        # Use weighted average of per-task quantiles
        q01_array = np.array(q01_list)
        q99_array = np.array(q99_list)
        if percentile_mixing_method == "weighted_average":
            weighted_q01 = np.average(q01_array, axis=0, weights=normalized_weights).tolist()
            weighted_q99 = np.average(q99_array, axis=0, weights=normalized_weights).tolist()
            # std_q01 = np.std(q01_array, axis=0).tolist()
            # std_q99 = np.std(q99_array, axis=0).tolist()
            # print(modality)
            # print(f"{std_q01=}, {std_q99=}")
            # print(f"{weighted_q01=}, {weighted_q99=}")
        elif percentile_mixing_method == "min_max":
            weighted_q01 = np.min(q01_array, axis=0).tolist()
            weighted_q99 = np.max(q99_array, axis=0).tolist()
        else:
            raise ValueError(f"Invalid percentile mixing method: {percentile_mixing_method}")

        # Store the overall statistics for the modality
        overall_stats[modality] = {
            "min": overall_min,
            "max": overall_max,
            "mean": overall_mean,
            "std": overall_std,
            "q01": weighted_q01,
            "q99": weighted_q99,
        }

    return overall_stats
merge_metadata staticmethod
merge_metadata(metadatas: list[DatasetMetadata], dataset_sampling_weights: list[float], percentile_mixing_method: str) -> DatasetMetadata

Merge multiple metadata into one.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
@staticmethod
def merge_metadata(
    metadatas: list[DatasetMetadata],
    dataset_sampling_weights: list[float],
    percentile_mixing_method: str,
) -> DatasetMetadata:
    """Merge multiple metadata into one."""
    # Work on plain dicts so we can reassemble a validated model at the end.
    raw_dicts = [md.model_dump(mode="json") for md in metadatas]
    merged: dict = {}

    # All datasets being merged must share one embodiment tag.
    assert all(
        md.embodiment_tag == metadatas[0].embodiment_tag for md in metadatas
    ), "All metadata must have the same embodiment tag"
    merged["embodiment_tag"] = metadatas[0].embodiment_tag

    # Combine per-dataset state/action statistics using the sampling weights.
    merged["statistics"] = {
        group: LeRobotMixtureDataset.compute_overall_statistics(
            per_task_stats=[d["statistics"][group] for d in raw_dicts],
            dataset_sampling_weights=dataset_sampling_weights,
            percentile_mixing_method=percentile_mixing_method,
        )
        for group in ("state", "action")
    }

    # Collect the distinct (serialized) configs seen for each modality;
    # merging is only well-defined when every dataset agrees.
    config_variants = defaultdict(set)
    for d in raw_dicts:
        for modality, configs in d["modalities"].items():
            config_variants[modality].add(json.dumps(configs))
    merged["modalities"] = {}
    for modality, configs in config_variants.items():
        assert (
            len(configs) == 1
        ), f"Multiple modality configs for modality {modality}: {list(configs)}"
        merged["modalities"][modality] = json.loads(configs.pop())

    return DatasetMetadata.model_validate(merged)
update_metadata
update_metadata(metadata_config: dict, cached_statistics_path: Path | str | None = None) -> None

Merge multiple metadatas into one and set the transforms with the merged metadata.

Parameters:

Name Type Description Default
metadata_config dict

Configuration for the metadata. "percentile_mixing_method": The method to mix the percentiles, either "weighted_average" or "min_max". weighted_average: Use the weighted average of the percentiles using the weight used in sampling the datasets. min_max: Use the min of the 1st percentile and max of the 99th percentile.

required
Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def update_metadata(self, metadata_config: dict, cached_statistics_path: Path | str | None = None) -> None:
    """
    Merge multiple metadatas into one and set the transforms with the merged metadata.

    Args:
        metadata_config (dict): Configuration for the metadata.
            "percentile_mixing_method": The method to mix the percentiles, either "weighted_average" or "min_max".
                weighted_average: Use the weighted average of the percentiles using the weight used in sampling the datasets.
                min_max: Use the min of the 1st percentile and max of the 99th percentile.
        cached_statistics_path: Optional path to previously saved statistics;
            when loading succeeds, recomputation is skipped entirely.
    """
    # Fast path: reuse cached statistics when a path is supplied and valid.
    if cached_statistics_path is not None:
        try:
            cached_stats = self.load_merged_statistics(cached_statistics_path)
            self.apply_cached_statistics(cached_stats)
            return
        except (FileNotFoundError, KeyError, ValidationError) as e:
            print(f"Failed to load cached statistics: {e}")
            print("Falling back to computing statistics from scratch...")

    self.tag = EmbodimentTag.NEW_EMBODIMENT.value
    self.merged_metadata: dict[str, DatasetMetadata] = {}

    # Bucket each dataset's metadata by its embodiment tag.
    grouped: dict[str, list[DatasetMetadata]] = {}
    for ds in self.datasets:
        grouped.setdefault(ds.tag, []).append(ds.metadata)

    # Merge each tag group with the configured percentile mixing method.
    for tag, tag_metadatas in grouped.items():
        self.merged_metadata[tag] = self.merge_metadata(
            metadatas=tag_metadatas,
            dataset_sampling_weights=self.dataset_sampling_weights.tolist(),
            percentile_mixing_method=metadata_config["percentile_mixing_method"],
        )

    # Push the merged metadata back into every dataset's transform pipeline.
    for ds in self.datasets:
        ds.set_transforms_metadata(self.merged_metadata[ds.tag])
save_dataset_statistics
save_dataset_statistics(save_path: Path | str, format: str = 'json') -> None

Save merged dataset statistics to the specified path in the required format. Only includes statistics for keys that are actually used in the datasets; keys are written in the order they are collected from the datasets.

Parameters:

Name Type Description Default
save_path Path | str

Path to save the statistics file

required
format str

Save format, currently only supports "json"

'json'
Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def save_dataset_statistics(self, save_path: Path | str, format: str = "json") -> None:
    """
    Save merged dataset statistics to the specified path in the required format.
    Only includes statistics for keys that are actually used in the datasets;
    keys are written in the order they are collected from the datasets.

    Args:
        save_path (Path | str): Path to save the statistics file
        format (str): Save format, currently only supports "json"

    Raises:
        ValueError: If `format` is not "json".
    """
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    # Collect the union of actually-used action/state keys across all
    # datasets, preserving first-seen order.
    all_used_action_keys: list[str] = []
    all_used_state_keys: list[str] = []
    for dataset in self.datasets:
        used_action_keys, used_state_keys = get_used_modality_keys(dataset.modality_keys)
        for used_action_key in used_action_keys:
            if used_action_key not in all_used_action_keys:
                all_used_action_keys.append(used_action_key)
        for used_state_key in used_state_keys:
            if used_state_key not in all_used_state_keys:
                all_used_state_keys.append(used_state_key)

    def _filter_stats(stats, used_keys):
        # Keep only the used keys, in used-key order.
        return {key: stats[key] for key in used_keys if key in stats}

    # Build the data structure to save, organized by embodiment tag.
    statistics_data = {}
    for tag, merged_metadata in self.merged_metadata.items():
        tag_stats = {}

        # Process action statistics (plus mask and norm_mode for eval).
        if hasattr(merged_metadata.statistics, 'action') and merged_metadata.statistics.action:
            filtered_action_stats = _filter_stats(
                merged_metadata.statistics.action, all_used_action_keys
            )
            if filtered_action_stats:
                combined_action_stats = combine_modality_stats(filtered_action_stats)

                mask = generate_action_mask_for_used_keys(
                    merged_metadata.modalities.action, filtered_action_stats.keys()
                )
                combined_action_stats["mask"] = mask

                # Auto-detect and save norm_mode from dataset transforms
                # so eval can correctly unnormalize without guessing
                norm_mode = self._detect_action_norm_mode()
                if norm_mode is not None:
                    combined_action_stats["norm_mode"] = norm_mode
                    print(f"[save_dataset_statistics] Saved norm_mode='{norm_mode}' for tag '{tag}'")
                else:
                    print(f"[save_dataset_statistics] WARNING: norm_mode detection failed for tag '{tag}', "
                          f"no norm_mode saved. Eval will default to 'q99'.")

                tag_stats["action"] = combined_action_stats

        # Process state statistics
        if hasattr(merged_metadata.statistics, 'state') and merged_metadata.statistics.state:
            filtered_state_stats = _filter_stats(
                merged_metadata.statistics.state, all_used_state_keys
            )
            if filtered_state_stats:
                tag_stats["state"] = combine_modality_stats(filtered_state_stats)

        # Add dataset counts
        tag_stats.update(self._get_dataset_counts(tag))

        statistics_data[tag] = tag_stats

    # Save file
    if format.lower() == "json":
        if not str(save_path).endswith('.json'):
            save_path = save_path.with_suffix('.json')
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(statistics_data, f, indent=2, ensure_ascii=False)
    else:
        raise ValueError(f"Unsupported format: {format}. Currently only 'json' is supported.")

    print(f"Merged dataset statistics saved to: {save_path}")
    print(f"Used action keys (reordered): {list(all_used_action_keys)}")
    print(f"Used state keys (reordered): {list(all_used_state_keys)}")
load_merged_statistics classmethod
load_merged_statistics(load_path: Path | str) -> dict

Load merged dataset statistics from file.

Parameters:

Name Type Description Default
load_path Path | str

Path to the statistics file

required

Returns:

Name Type Description
dict dict

Dictionary containing merged statistics

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
@classmethod
def load_merged_statistics(cls, load_path: Path | str) -> dict:
    """
    Load merged dataset statistics from file.

    Args:
        load_path (Path | str): Path to the statistics file

    Returns:
        dict: Dictionary containing merged statistics
    """
    load_path = Path(load_path)
    if not load_path.exists():
        raise FileNotFoundError(f"Statistics file not found: {load_path}")

    if load_path.suffix.lower() == '.json':
        with open(load_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    elif load_path.suffix.lower() == '.pkl':
        import pickle
        with open(load_path, 'rb') as f:
            return pickle.load(f)
    else:
        raise ValueError(f"Unsupported file format: {load_path.suffix}")
apply_cached_statistics
apply_cached_statistics(cached_statistics: dict) -> None

Apply cached statistics to avoid recomputation.

Parameters:

Name Type Description Default
cached_statistics dict

Statistics loaded from file

required
Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def apply_cached_statistics(self, cached_statistics: dict) -> None:
    """
    Apply cached statistics to avoid recomputation.

    Args:
        cached_statistics (dict): Statistics loaded from file
    """
    # Sanity check: the cache must describe the same dataset set we hold now.
    if "metadata" in cached_statistics:
        cached_names = set(cached_statistics["metadata"]["dataset_names"])
        current_names = {ds.dataset_name for ds in self.datasets}

        if cached_names != current_names:
            print("Warning: Cached statistics dataset names don't match current datasets.")
            print(f"Cached: {cached_names}")
            print(f"Current: {current_names}")
            return

    # Rebuild merged metadata per embodiment tag from the cached entries.
    self.merged_metadata = {}
    for tag, stats_data in cached_statistics.items():
        if tag == "metadata":  # Skip metadata field
            continue

        # Reconstruct a DatasetMetadata-shaped dict. This is simplified —
        # combined statistics are stored as-is rather than split back into
        # per-sub-key entries.
        metadata_dict = {
            "embodiment_tag": tag,
            "statistics": {
                "action": stats_data.get("action", {}),
                "state": stats_data.get("state", {}),
            },
            "modalities": {},
        }

        self.merged_metadata[tag] = DatasetMetadata.model_validate(metadata_dict)

    # Push the restored metadata into each dataset's transform pipeline.
    for ds in self.datasets:
        if ds.tag in self.merged_metadata:
            ds.set_transforms_metadata(self.merged_metadata[ds.tag])

    print(f"Applied cached statistics for {len(self.merged_metadata)} embodiment tags.")
calculate_dataset_statistics
calculate_dataset_statistics(parquet_paths: list[Path]) -> dict

Calculate the dataset statistics of all columns for a list of parquet files.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def calculate_dataset_statistics(parquet_paths: list[Path]) -> dict:
    """Calculate the dataset statistics of all columns for a list of parquet files.

    Args:
        parquet_paths: Parquet trajectory files; each is loaded via
            `_safe_read_parquet` and concatenated row-wise.

    Returns:
        Mapping from column name to its per-dimension statistics
        (mean/std/min/max/q01/q99, each a list of floats). Columns whose
        name contains "task_info", or whose cells cannot be stacked into a
        float matrix, are skipped.
    """
    # Collect all the data
    all_low_dim_data_list = []
    for parquet_path in tqdm(
        sorted(parquet_paths),
        desc="Collecting all parquet files...",
    ):
        all_low_dim_data_list.append(_safe_read_parquet(parquet_path))

    all_low_dim_data = pd.concat(all_low_dim_data_list, axis=0)

    # Compute dataset statistics
    dataset_statistics = {}
    for le_modality in tqdm(all_low_dim_data.columns, desc="Processing modalities"):
        if "task_info" in le_modality:
            continue
        print(f"Computing statistics for {le_modality}...")
        # Best-effort: skip columns with empty/ragged/non-numeric cells that
        # cannot be coerced into a float32 matrix.
        try:
            np_data = np.vstack(
                [np.asarray(x, dtype=np.float32) for x in all_low_dim_data[le_modality]]
            )
        except Exception as e:
            print(f"Warning: Failed to process modality {le_modality} due to error: {e}")
            continue

        dataset_statistics[le_modality] = {
            "mean": np.mean(np_data, axis=0).tolist(),
            "std": np.std(np_data, axis=0).tolist(),
            "min": np.min(np_data, axis=0).tolist(),
            "max": np.max(np_data, axis=0).tolist(),
            "q01": np.quantile(np_data, 0.01, axis=0).tolist(),
            "q99": np.quantile(np_data, 0.99, axis=0).tolist(),
        }
    return dataset_statistics
calculate_delta_action_statistics
calculate_delta_action_statistics(parquet_paths: list[Path], lerobot_modality_meta: LeRobotModalityMetadata, action_keys_full: list[str], state_keys_full: list[str], action_indices: list[int], state_indices: list[int], action_mode_apply_keys: list[str] | None = None, action_mode_state_map: dict[str, str] | None = None, base_stats: dict | None = None) -> dict

Calculate action statistics using delta mode.

Rule
  • For t>0: a_t - a_{t-1}
  • For t=0: a_0 - s_0

Mapping rule (only two cases): 1) Use explicit action_mode_state_map if provided. 2) Otherwise, replace 'action.' with 'state.' directly.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def calculate_delta_action_statistics(
    parquet_paths: list[Path],
    lerobot_modality_meta: "LeRobotModalityMetadata",
    action_keys_full: list[str],
    state_keys_full: list[str],
    action_indices: list[int],
    state_indices: list[int],
    action_mode_apply_keys: list[str] | None = None,
    action_mode_state_map: dict[str, str] | None = None,
    base_stats: dict | None = None,
) -> dict:
    """
    Calculate action statistics using delta mode.

    Rule:
      - For t>0: a_t - a_{t-1}
      - For t=0: a_0 - s_0

    Mapping rule (only two cases):
      1) Use explicit action_mode_state_map if provided.
      2) Otherwise, replace 'action.' with 'state.' directly.

    Args:
        parquet_paths: Parquet trajectory files to scan.
        lerobot_modality_meta: Modality metadata used to resolve column slices.
        action_keys_full: All action keys available in the dataset.
        state_keys_full: All state keys available in the dataset.
        action_indices: Relative step offsets forming each action chunk.
        state_indices: Relative step offsets forming each state chunk.
        action_mode_apply_keys: Action keys delta mode applies to (None falls
            back to the default mapping rule above).
        action_mode_state_map: Explicit action-key -> state-key mapping.
        base_stats: Precomputed per-column statistics; computed from
            `parquet_paths` when None. Columns not in delta mode keep these.

    Returns:
        A deep copy of `base_stats` where each delta-mode action column's
        mean/std/min/max/q01/q99 are recomputed over the delta values.
    """
    if base_stats is None:
        base_stats = calculate_dataset_statistics(parquet_paths)

    # Resolve, per parquet action column, the list of
    # (action slice, state column, state slice, padding strategies) tuples
    # that delta mode should be applied to.
    action_col_slices = _get_action_col_slices(
        lerobot_modality_meta, action_keys_full, state_keys_full, action_mode_apply_keys, action_mode_state_map
    )
    if not action_col_slices:
        raise ValueError("No action columns found in the dataset.")

    def _get_chunk(array: np.ndarray, step_indices: np.ndarray, padding_strategy: str) -> np.ndarray:
        # Gather rows of `array` at `step_indices`; out-of-range steps are
        # padded by clamping to the first/last row ("first_last") or with
        # zeros ("zero").
        max_length = array.shape[0]
        front_padding = step_indices < 0
        end_padding = step_indices >= max_length
        padding_positions = np.logical_or(front_padding, end_padding)
        output = np.zeros((len(step_indices), array.shape[1]), dtype=array.dtype)
        if (~padding_positions).any():
            output[~padding_positions] = array[step_indices[~padding_positions]]
        if padding_positions.any():
            if padding_strategy == "first_last":
                output[front_padding] = array[0]
                output[end_padding] = array[-1]
            elif padding_strategy == "zero":
                output[padding_positions] = 0
            else:
                raise ValueError(f"Invalid padding strategy: {padding_strategy}")
        return output

    # Accumulate one delta chunk per (column, base step).
    # NOTE(review): chunks from consecutive base steps overlap, so each
    # transition is counted once per offset in `action_indices` — presumably
    # intentional to mirror training-time chunk sampling; confirm.
    accum: dict[str, list[np.ndarray]] = {col: [] for col in action_col_slices.keys()}
    for parquet_path in tqdm(sorted(list(parquet_paths)), desc="Collecting delta action stats"):
        data = _safe_read_parquet(parquet_path)
        trajectory_length = len(data)
        for action_col, slice_list in action_col_slices.items():
            if action_col not in data.columns:
                raise ValueError(f"{action_col} not found in parquet columns.")
            action_matrix = np.stack(data[action_col])
            # Padding strategy of the first slice is reused for the whole
            # action column chunk.
            action_padding_ref = slice_list[0][3]
            prepared_slices = []
            for a_slice, state_col, s_slice, action_padding, state_padding in slice_list:
                if state_col not in data.columns:
                    raise ValueError(f"{state_col} not found in parquet columns.")
                state_matrix = np.stack(data[state_col])
                state_part_full = state_matrix[:, s_slice[0] : s_slice[1]]
                prepared_slices.append((a_slice, state_part_full, state_padding))
            for base_index in range(trajectory_length):
                action_steps = np.array(action_indices) + base_index
                action_chunk_full = _get_chunk(action_matrix, action_steps, action_padding_ref)

                for a_slice, state_part_full, state_padding in prepared_slices:
                    action_part_chunk = action_chunk_full[:, a_slice[0] : a_slice[1]]
                    state_chunk = _get_chunk(state_part_full, np.array(state_indices) + base_index, state_padding)
                    if action_part_chunk.shape[1] != state_chunk.shape[1]:
                        raise ValueError(f"Action/state dim mismatch for {action_col}:{a_slice}")

                    # Delta rule: out[t] = a_t - a_{t-1} for t>0, out[0] = a_0 - s_0.
                    # (out[0] is written after out[1:], so it is not clobbered.)
                    out = action_part_chunk.copy()
                    if len(out) > 1:
                        out[1:] = action_part_chunk[1:] - action_part_chunk[:-1]
                    out[0] = action_part_chunk[0] - state_chunk[0]
                    action_chunk_full[:, a_slice[0] : a_slice[1]] = out

                accum[action_col].append(action_chunk_full)

    # Only delta-mode columns are overwritten; everything else keeps the
    # base statistics.
    delta_stats = copy.deepcopy(base_stats)
    for action_col, series_list in accum.items():
        if not series_list:
            continue
        all_values = np.concatenate(series_list, axis=0).astype(np.float32)
        delta_stats[action_col] = {
            "mean": np.mean(all_values, axis=0).tolist(),
            "std": np.std(all_values, axis=0).tolist(),
            "min": np.min(all_values, axis=0).tolist(),
            "max": np.max(all_values, axis=0).tolist(),
            "q01": np.quantile(all_values, 0.01, axis=0).tolist(),
            "q99": np.quantile(all_values, 0.99, axis=0).tolist(),
        }
    return delta_stats
calculate_rel_action_statistics
calculate_rel_action_statistics(parquet_paths: list[Path], lerobot_modality_meta: LeRobotModalityMetadata, action_keys_full: list[str], state_keys_full: list[str], action_indices: list[int], state_indices: list[int], action_mode_apply_keys: list[str] | None = None, action_mode_state_map: dict[str, str] | None = None, base_stats: dict | None = None) -> dict

Calculate action statistics using rel mode.

Rule
  • For all t: a_t - s_0

Mapping rule (only two cases): 1) Use explicit action_mode_state_map if provided. 2) Otherwise, replace 'action.' with 'state.' directly.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def calculate_rel_action_statistics(
    parquet_paths: list[Path],
    lerobot_modality_meta: "LeRobotModalityMetadata",
    action_keys_full: list[str],
    state_keys_full: list[str],
    action_indices: list[int],
    state_indices: list[int],
    action_mode_apply_keys: list[str] | None = None,
    action_mode_state_map: dict[str, str] | None = None,
    base_stats: dict | None = None,
) -> dict:
    """
    Calculate action statistics using rel mode.

    Rule:
      - For all t: a_t - s_0

    Mapping rule (only two cases):
      1) Use explicit action_mode_state_map if provided.
      2) Otherwise, replace 'action.' with 'state.' directly.

    Args:
        parquet_paths: Parquet trajectory files to scan.
        lerobot_modality_meta: Modality metadata used to resolve column slices.
        action_keys_full: All action keys available in the dataset.
        state_keys_full: All state keys available in the dataset.
        action_indices: Relative step offsets forming each action chunk.
        state_indices: Relative step offsets forming each state chunk.
        action_mode_apply_keys: Action keys rel mode applies to (None falls
            back to the default mapping rule above).
        action_mode_state_map: Explicit action-key -> state-key mapping.
        base_stats: Precomputed per-column statistics; computed from
            `parquet_paths` when None. Columns not in rel mode keep these.

    Returns:
        A deep copy of `base_stats` where each rel-mode action column's
        mean/std/min/max/q01/q99 are recomputed over the relative values.
    """
    if base_stats is None:
        base_stats = calculate_dataset_statistics(parquet_paths)

    # Resolve, per parquet action column, the list of
    # (action slice, state column, state slice, padding strategies) tuples
    # that rel mode should be applied to.
    action_col_slices = _get_action_col_slices(
        lerobot_modality_meta, action_keys_full, state_keys_full, action_mode_apply_keys, action_mode_state_map
    )
    if not action_col_slices:
        raise ValueError("No action columns found in the dataset.")

    def _get_chunk(array: np.ndarray, step_indices: np.ndarray, padding_strategy: str) -> np.ndarray:
        # Gather rows of `array` at `step_indices`; out-of-range steps are
        # padded by clamping to the first/last row ("first_last") or with
        # zeros ("zero").
        max_length = array.shape[0]
        front_padding = step_indices < 0
        end_padding = step_indices >= max_length
        padding_positions = np.logical_or(front_padding, end_padding)
        output = np.zeros((len(step_indices), array.shape[1]), dtype=array.dtype)
        if (~padding_positions).any():
            output[~padding_positions] = array[step_indices[~padding_positions]]
        if padding_positions.any():
            if padding_strategy == "first_last":
                output[front_padding] = array[0]
                output[end_padding] = array[-1]
            elif padding_strategy == "zero":
                output[padding_positions] = 0
            else:
                raise ValueError(f"Invalid padding strategy: {padding_strategy}")
        return output

    # Accumulate one rel chunk per (column, base step).
    # NOTE(review): chunks from consecutive base steps overlap, so each step
    # is counted once per offset in `action_indices` — presumably intentional
    # to mirror training-time chunk sampling; confirm.
    accum: dict[str, list[np.ndarray]] = {col: [] for col in action_col_slices.keys()}
    for parquet_path in tqdm(sorted(list(parquet_paths)), desc="Collecting rel action stats"):
        data = _safe_read_parquet(parquet_path)
        trajectory_length = len(data)
        for action_col, slice_list in action_col_slices.items():
            if action_col not in data.columns:
                raise ValueError(f"{action_col} not found in parquet columns.")
            action_matrix = np.stack(data[action_col])
            # Padding strategy of the first slice is reused for the whole
            # action column chunk.
            action_padding_ref = slice_list[0][3]
            prepared_slices = []
            for a_slice, state_col, s_slice, action_padding, state_padding in slice_list:
                if state_col not in data.columns:
                    raise ValueError(f"{state_col} not found in parquet columns.")
                state_matrix = np.stack(data[state_col])
                state_part_full = state_matrix[:, s_slice[0] : s_slice[1]]
                prepared_slices.append((a_slice, state_part_full, state_padding))
            for base_index in range(trajectory_length):
                action_steps = np.array(action_indices) + base_index
                action_chunk_full = _get_chunk(action_matrix, action_steps, action_padding_ref)

                for a_slice, state_part_full, state_padding in prepared_slices:
                    action_part_chunk = action_chunk_full[:, a_slice[0] : a_slice[1]]
                    state_chunk = _get_chunk(state_part_full, np.array(state_indices) + base_index, state_padding)
                    if action_part_chunk.shape[1] != state_chunk.shape[1]:
                        raise ValueError(f"Action/state dim mismatch for {action_col}:{a_slice}")

                    # Rel rule: every step is expressed relative to the first
                    # state in the chunk, out[t] = a_t - s_0.
                    out = action_part_chunk - state_chunk[0]
                    action_chunk_full[:, a_slice[0] : a_slice[1]] = out

                accum[action_col].append(action_chunk_full)

    # Only rel-mode columns are overwritten; everything else keeps the
    # base statistics.
    rel_stats = copy.deepcopy(base_stats)
    for action_col, series_list in accum.items():
        if not series_list:
            continue
        all_values = np.concatenate(series_list, axis=0).astype(np.float32)
        rel_stats[action_col] = {
            "mean": np.mean(all_values, axis=0).tolist(),
            "std": np.std(all_values, axis=0).tolist(),
            "min": np.min(all_values, axis=0).tolist(),
            "max": np.max(all_values, axis=0).tolist(),
            "q01": np.quantile(all_values, 0.01, axis=0).tolist(),
            "q99": np.quantile(all_values, 0.99, axis=0).tolist(),
        }
    return rel_stats
combine_modality_stats
combine_modality_stats(modality_stats: dict) -> dict

Combine statistics from all sub-keys under a modality.

Parameters:

Name Type Description Default
modality_stats dict

Statistics for a modality, containing multiple sub-keys. Each sub-key contains DatasetStatisticalValues object.

required

Returns:

Name Type Description
dict dict

Combined statistics

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def combine_modality_stats(modality_stats: dict) -> dict:
    """
    Combine statistics from all sub-keys under a modality.

    Args:
        modality_stats (dict): Statistics for a modality, containing multiple sub-keys.
                               Each sub-key contains DatasetStatisticalValues object.

    Returns:
        dict: Combined statistics
    """
    stat_names = ("mean", "std", "max", "min", "q01", "q99")
    combined: dict = {name: [] for name in stat_names}

    # Concatenate per-sub-key vectors, following the dict's key order.
    for subkey_stats in modality_stats.values():
        for name in stat_names:
            value = getattr(subkey_stats, name)
            if isinstance(value, (list, tuple)):
                combined[name].extend(value)
            elif hasattr(value, 'tolist'):
                # NDArray-like value: flatten to a plain list.
                combined[name].extend(value.tolist())
            else:
                # Scalar value: append as a single float.
                combined[name].append(float(value))

    return combined
generate_action_mask_for_used_keys
generate_action_mask_for_used_keys(action_modalities: dict, used_action_keys_ordered) -> list[bool]

Generate mask based on action modalities, but only for used keys. All mask entries are True, so every action dimension (including gripper) is normalized uniformly.

Parameters:

Name Type Description Default
action_modalities dict

Configuration information for action modalities.

required
used_action_keys_ordered

Iterable of actually used action keys in the correct order.

required

Returns:

Type Description
list[bool]

list[bool]: List of mask values

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def generate_action_mask_for_used_keys(action_modalities: dict, used_action_keys_ordered) -> list[bool]:
    """
    Build a per-dimension mask covering only the action keys actually used.

    Args:
        action_modalities (dict): Configuration information for action modalities.
        used_action_keys_ordered: Iterable of used action keys, in the same
            order the statistics were combined.

    Returns:
        list[bool]: One mask value per action dimension.
    """
    mask: list[bool] = []

    # Walk keys in the same order the statistics were combined so mask
    # positions line up with the concatenated action vector.
    for key in used_action_keys_ordered:
        if key not in action_modalities:
            continue
        config = action_modalities[key]

        # Dimension count comes from the config's shape when available;
        # otherwise assume a single scalar dimension.
        shape = getattr(config, "shape", None)
        dims = shape[0] if shape is not None and len(shape) > 0 else 1

        # Normalize ALL action dimensions uniformly (including gripper).
        # Previous behavior masked gripper dims as False (skipping
        # normalization), which left gripper targets {0,1} on a different
        # scale than the other dims in [-1, 1] and drowned out their
        # regression-loss signal.
        mask.extend([True] * dims)

    return mask
get_used_modality_keys
get_used_modality_keys(modality_keys: dict) -> tuple[list, list]

Extract used action and state keys from modality configuration.

Source code in AlphaBrain/dataloader/gr00t_lerobot/datasets.py
def get_used_modality_keys(modality_keys: dict) -> tuple[list, list]:
    """Extract used action and state keys from modality configuration.

    Args:
        modality_keys (dict): Modality configuration; the "action" and
            "state" entries hold fully-qualified keys such as
            "action.left_arm" or "state.joint_shoulder_y".

    Returns:
        tuple[list, list]: (action keys, state keys) with their
        "action." / "state." prefixes stripped, in the original order.
    """
    used_action_keys = []
    used_state_keys = []

    # Strip only the leading prefix. str.replace() was used here before, but
    # it removes *every* occurrence of the substring, so a key such as
    # "action.pre_action.x" would be corrupted; removeprefix is exact.
    for action_key in modality_keys.get("action", []):
        if action_key.startswith("action."):
            used_action_keys.append(action_key.removeprefix("action."))

    for state_key in modality_keys.get("state", []):
        if state_key.startswith("state."):
            used_state_keys.append(state_key.removeprefix("state."))

    return used_action_keys, used_state_keys

embodiment_tags

EmbodimentTag

Bases: Enum

GR1 class-attribute instance-attribute
GR1 = 'gr1'

The GR1 dataset.

OXE_DROID class-attribute instance-attribute
OXE_DROID = 'oxe_droid'

The OxE Droid dataset.

OXE_BRIDGE class-attribute instance-attribute
OXE_BRIDGE = 'oxe_bridge'

The OxE Bridge dataset.

OXE_RT1 class-attribute instance-attribute
OXE_RT1 = 'oxe_rt1'

The OxE RT-1 dataset.

AGIBOT_GENIE1 class-attribute instance-attribute
AGIBOT_GENIE1 = 'agibot_genie1'

The AgiBot Genie-1 with gripper dataset.

NEW_EMBODIMENT class-attribute instance-attribute
NEW_EMBODIMENT = 'new_embodiment'

Any new embodiment for finetuning.

FRANKA class-attribute instance-attribute
FRANKA = 'franka'

The Franka Emika Panda robot.

mixtures

mixtures.py

Defines a registry of dataset mixtures and weights for the Open-X Embodiment Datasets. Each dataset is associated with a float "sampling weight"

schema

RotationType

Bases: Enum

Type of rotation representation

LeRobotModalityField

Bases: BaseModel

Metadata for a LeRobot modality field.

LeRobotStateActionMetadata

Bases: LeRobotModalityField

Metadata for a LeRobot modality.

LeRobotStateMetadata

Bases: LeRobotStateActionMetadata

Metadata for a LeRobot state modality.

LeRobotActionMetadata

Bases: LeRobotStateActionMetadata

Metadata for a LeRobot action modality.

LeRobotModalityMetadata

Bases: BaseModel

Metadata for a LeRobot modality.

get_key_meta
get_key_meta(key: str) -> LeRobotModalityField

Get the metadata for a key in the LeRobot modality metadata.

Parameters:

Name Type Description Default
key str

The key to get the metadata for.

required

Returns:

Name Type Description
LeRobotModalityField LeRobotModalityField

The metadata for the key.

Example

lerobot_modality_meta = LeRobotModalityMetadata.model_validate(U.load_json(modality_meta_path))
lerobot_modality_meta.get_key_meta("state.joint_shoulder_y")
lerobot_modality_meta.get_key_meta("video.main_camera")
lerobot_modality_meta.get_key_meta("annotation.human.action.task_description")

Source code in AlphaBrain/dataloader/gr00t_lerobot/schema.py
def get_key_meta(self, key: str) -> LeRobotModalityField:
    """Get the metadata for a key in the LeRobot modality metadata.

    Args:
        key (str): The key to get the metadata for.

    Returns:
        LeRobotModalityField: The metadata for the key.

    Example:
        lerobot_modality_meta = LeRobotModalityMetadata.model_validate(U.load_json(modality_meta_path))
        lerobot_modality_meta.get_key_meta("state.joint_shoulder_y")
        lerobot_modality_meta.get_key_meta("video.main_camera")
        lerobot_modality_meta.get_key_meta("annotation.human.action.task_description")
    """
    # Keys look like "<modality>.<subkey>"; the subkey itself may contain
    # dots, so split only on the first one.
    modality, _, subkey = key.partition(".")

    if modality == "annotation":
        assert (
            self.annotation is not None
        ), "Trying to get annotation metadata for a dataset with no annotations"
        table = self.annotation
    elif modality == "state":
        table = self.state
    elif modality == "action":
        table = self.action
    elif modality == "video":
        table = self.video
    else:
        raise ValueError(f"Key: {key}, unexpected modality: {modality}")

    # All modalities share the same lookup-or-raise behavior.
    if subkey not in table:
        raise ValueError(
            f"Key: {key}, {modality} key {subkey} not found in metadata, available {modality} keys: {table.keys()}"
        )
    return table[subkey]
VideoMetadata

Bases: BaseModel

Metadata of the video modality

DatasetMetadata

Bases: BaseModel

Metadata of the trainable dataset

Changes
  • Update to use the new RawCommitHashMetadataMetadata_V1_2

transform

base
ModalityTransform

Bases: BaseModel, ABC

Abstract class for transforming data modalities, e.g. video frame augmentation or action normalization.

set_metadata
set_metadata(dataset_metadata: DatasetMetadata)

Set the dataset metadata. This is useful for transforms that need to know the dataset metadata, e.g. to normalize actions. Subclasses can override this method if they need to do something more complex.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/base.py
def set_metadata(self, dataset_metadata: DatasetMetadata):
    """
    Store the dataset metadata on the transform.

    Transforms that need dataset-level information (e.g. statistics used to
    normalize actions) read it back from ``self.dataset_metadata``.
    Subclasses may override this method to perform additional setup.
    """
    self.dataset_metadata = dataset_metadata
apply abstractmethod
apply(data: dict[str, Any]) -> dict[str, Any]

Apply the transformation to the data corresponding to keys matching the apply_to regular expression and return the processed data.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/base.py
# Required hook: concrete transforms produce the processed data dict here.
@abstractmethod
def apply(self, data: dict[str, Any]) -> dict[str, Any]:
    """Apply the transformation to the data corresponding to keys matching the `apply_to` regular expression and return the processed data."""
InvertibleModalityTransform

Bases: ModalityTransform

unapply abstractmethod
unapply(data: dict[str, Any]) -> dict[str, Any]

Reverse the transformation to the data corresponding to keys matching the apply_to regular expression and return the processed data.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/base.py
# Required hook: the inverse of apply() for invertible transforms.
@abstractmethod
def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
    """Reverse the transformation to the data corresponding to keys matching the `apply_to` regular expression and return the processed data."""
ComposedModalityTransform

Bases: ModalityTransform

Compose multiple modality transforms.

concat
ConcatTransform

Bases: InvertibleModalityTransform

Concatenate the keys according to specified order.

get_state_action_dims
get_state_action_dims(key: str) -> int

Get the dimension of a state or action key from the dataset metadata.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/concat.py
def get_state_action_dims(self, key: str) -> int:
    """Return the flat dimension of a state or action key from the dataset metadata."""
    # State/action shape metadata must be one-dimensional; its single entry
    # is the key's dimension count.
    shape = self.get_modality_metadata(key).shape
    assert len(shape) == 1, f"{shape=}"
    return shape[0]
set_metadata
set_metadata(dataset_metadata: DatasetMetadata)

Set the metadata and compute the dimensions of the state and action keys.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/concat.py
def set_metadata(self, dataset_metadata: DatasetMetadata):
    """Set the metadata and pre-compute the state and action key dimensions."""
    super().set_metadata(dataset_metadata)
    # Cache the flat dimension of every key in the configured concat orders.
    for concat_order, dims in (
        (self.action_concat_order, self.action_dims),
        (self.state_concat_order, self.state_dims),
    ):
        if concat_order is None:
            continue
        for key in concat_order:
            dims[key] = self.get_state_action_dims(key)
state_action
RotationTransform
RotationTransform(from_rep='axis_angle', to_rep='rotation_6d')

Adapted from https://github.com/real-stanford/diffusion_policy/blob/548a52bbb105518058e27bf34dcf90bf6f73681a/diffusion_policy/model/common/rotation_transformer.py

Valid representations

Always use matrix as intermediate representation.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/state_action.py
def __init__(self, from_rep="axis_angle", to_rep="rotation_6d"):
    """
    Valid representations

    Always use matrix as intermediate representation: the forward pipeline
    is from_rep -> matrix -> to_rep, and the inverse pipeline undoes it.
    """
    def _split_euler(rep):
        # "euler_angles_<conv>" reps carry their convention as a suffix;
        # map r/p/y letters onto the X/Y/Z convention letters used by `pt`.
        if rep.startswith("euler_angles"):
            convention = rep.split("_")[-1]
            convention = convention.replace("r", "X").replace("p", "Y").replace("y", "Z")
            return "euler_angles", convention
        return rep, None

    from_rep, from_convention = _split_euler(from_rep)
    to_rep, to_convention = _split_euler(to_rep)

    assert from_rep != to_rep, f"from_rep and to_rep cannot be the same: {from_rep}"
    assert from_rep in self.valid_reps, f"Invalid from_rep: {from_rep}"
    assert to_rep in self.valid_reps, f"Invalid to_rep: {to_rep}"

    forward_funcs = []
    inverse_funcs = []

    # Step 1: from_rep -> matrix (inverse: matrix -> from_rep).
    if from_rep != "matrix":
        to_matrix = getattr(pt, f"{from_rep}_to_matrix")
        from_matrix = getattr(pt, f"matrix_to_{from_rep}")
        if from_convention is not None:
            to_matrix = functools.partial(to_matrix, convention=from_convention)
            from_matrix = functools.partial(from_matrix, convention=from_convention)
        forward_funcs.append(to_matrix)
        inverse_funcs.append(from_matrix)

    # Step 2: matrix -> to_rep (inverse: to_rep -> matrix).
    if to_rep != "matrix":
        from_matrix = getattr(pt, f"matrix_to_{to_rep}")
        to_matrix = getattr(pt, f"{to_rep}_to_matrix")
        if to_convention is not None:
            from_matrix = functools.partial(from_matrix, convention=to_convention)
            to_matrix = functools.partial(to_matrix, convention=to_convention)
        forward_funcs.append(from_matrix)
        inverse_funcs.append(to_matrix)

    self.forward_funcs = forward_funcs
    # Inverse funcs were collected in forward order; run them backwards.
    self.inverse_funcs = inverse_funcs[::-1]
StateActionToTensor

Bases: InvertibleModalityTransform

Transforms states and actions to tensors.

StateActionTransform

Bases: InvertibleModalityTransform

Class for state or action transform.

Parameters:

Name Type Description Default
apply_to list[str]

The keys in the modality to load and transform.

required
normalization_modes dict[str, str]

The normalization modes for each state key. If a state key in apply_to is not present in the dictionary, it will not be normalized.

required
target_rotations dict[str, str]

The target representations for each state key. If a state key in apply_to is not present in the dictionary, it will not be rotated.

required
StateActionPerturbation

Bases: ModalityTransform

Class for state or action perturbation.

Parameters:

Name Type Description Default
apply_to list[str]

The keys in the modality to load and transform.

required
std float

Standard deviation of the noise to be added to the state or action.

required
StateActionDropout

Bases: ModalityTransform

Class for state or action dropout.

Parameters:

Name Type Description Default
apply_to list[str]

The keys in the modality to load and transform.

required
dropout_prob float

Probability of dropping out a state or action.

required
StateActionSinCosTransform

Bases: ModalityTransform

Class for state or action sin-cos transform.

Parameters:

Name Type Description Default
apply_to list[str]

The keys in the modality to load and transform.

required
video
VideoTransform
VideoCrop

Bases: VideoTransform

get_transform
get_transform(mode: Literal['train', 'eval'] = 'train') -> Callable

Get the transform for the given mode.

Parameters:

Name Type Description Default
mode Literal['train', 'eval']

The mode to get the transform for.

'train'

Returns:

Name Type Description
Callable Callable

If mode is "train", return a random crop transform. If mode is "eval", return a center crop transform.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable:
    """Get the crop transform for the given mode.

    Args:
        mode (Literal["train", "eval"]): The mode to get the transform for.

    Returns:
        Callable: A random crop in "train" mode, a center crop in "eval" mode.
    """
    # 1. Check the input resolution: a single crop size is derived for every
    # key, so all configured video streams must share one resolution.
    assert (
        len(set(self.original_resolutions.values())) == 1
    ), f"All video keys must have the same resolution, got: {self.original_resolutions}"

    if self.height is None:
        # Neither dimension given: fall back to the native resolution.
        assert self.width is None, "Height and width must be either both provided and both None"
        self.width, self.height = self.original_resolutions[self.apply_to[0]]
    else:
        assert (
            self.width is not None
        ), "Height and width must be either both provided or both None"

    # 2. Create the transform at the scaled (height, width).
    crop_size = (int(self.height * self.scale), int(self.width * self.scale))
    if self.backend == "torchvision":
        if mode == "train":
            return T.RandomCrop(crop_size)
        if mode == "eval":
            return T.CenterCrop(crop_size)
        raise ValueError(f"Crop mode {mode} not supported")
    if self.backend == "albumentations":
        if mode == "train":
            return A.RandomCrop(height=crop_size[0], width=crop_size[1], p=1)
        if mode == "eval":
            return A.CenterCrop(height=crop_size[0], width=crop_size[1], p=1)
        raise ValueError(f"Crop mode {mode} not supported")
    raise ValueError(f"Backend {self.backend} not supported")
VideoResize

Bases: VideoTransform

get_transform
get_transform(mode: Literal['train', 'eval'] = 'train') -> Callable

Get the resize transform. Same transform for both train and eval.

Parameters:

Name Type Description Default
mode Literal['train', 'eval']

The mode to get the transform for.

'train'

Returns:

Name Type Description
Callable Callable

The resize transform.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable:
    """Get the resize transform. Same transform for both train and eval.

    Args:
        mode (Literal["train", "eval"]): The mode to get the transform for.

    Returns:
        Callable: The resize transform.
    """
    # Resolve the backend-specific interpolation flag up front.
    interp = self._get_interpolation(self.interpolation, self.backend)
    if interp is None:
        raise ValueError(
            f"Interpolation mode {self.interpolation} not supported for torchvision"
        )
    if self.backend == "torchvision":
        return T.Resize(
            (self.height, self.width), interpolation=interp, antialias=self.antialias
        )
    if self.backend == "albumentations":
        return A.Resize(height=self.height, width=self.width, interpolation=interp, p=1)
    raise ValueError(f"Backend {self.backend} not supported")
VideoRandomRotation

Bases: VideoTransform

get_transform
get_transform(mode: Literal['train', 'eval'] = 'train') -> Callable | None

Get the random rotation transform, only used in train mode.

Parameters:

Name Type Description Default
mode Literal['train', 'eval']

The mode to get the transform for.

'train'

Returns:

Type Description
Callable | None

Callable | None: The random rotation transform. None for eval mode.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None:
    """Get the random rotation transform, only used in train mode.

    Args:
        mode (Literal["train", "eval"]): The mode to get the transform for.

    Returns:
        Callable | None: The random rotation transform. None for eval mode.
    """
    if mode == "eval":
        return None
    interpolation = self._get_interpolation(self.interpolation, self.backend)
    if interpolation is None:
        raise ValueError(
            f"Interpolation mode {self.interpolation} not supported for torchvision"
        )
    if self.backend == "torchvision":
        return T.RandomRotation(self.degrees, interpolation=interpolation)  # type: ignore
    elif self.backend == "albumentations":
        return A.Rotate(limit=self.degrees, interpolation=interpolation, p=1)
    else:
        raise ValueError(f"Backend {self.backend} not supported")
VideoHorizontalFlip

Bases: VideoTransform

get_transform
get_transform(mode: Literal['train', 'eval'] = 'train') -> Callable | None

Get the horizontal flip transform, only used in train mode.

Parameters:

Name Type Description Default
mode Literal['train', 'eval']

The mode to get the transform for.

'train'

Returns:

Type Description
Callable | None

Callable | None: If mode is "train", return a horizontal flip transform. If mode is "eval", return None.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None:
    """Get the horizontal flip transform, only used in train mode.

    Args:
        mode (Literal["train", "eval"]): The mode to get the transform for.

    Returns:
        Callable | None: If mode is "train", return a horizontal flip transform. If mode is "eval", return None.
    """
    if mode == "eval":
        return None
    if self.backend == "torchvision":
        return T.RandomHorizontalFlip(self.p)
    elif self.backend == "albumentations":
        return A.HorizontalFlip(p=self.p)
    else:
        raise ValueError(f"Backend {self.backend} not supported")
VideoGrayscale

Bases: VideoTransform

get_transform
get_transform(mode: Literal['train', 'eval'] = 'train') -> Callable | None

Get the grayscale transform, only used in train mode.

Parameters:

Name Type Description Default
mode Literal['train', 'eval']

The mode to get the transform for.

'train'

Returns:

Type Description
Callable | None

Callable | None: If mode is "train", return a grayscale transform. If mode is "eval", return None.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None:
    """Get the grayscale transform, only used in train mode.

    Args:
        mode (Literal["train", "eval"]): The mode to get the transform for.

    Returns:
        Callable | None: If mode is "train", return a grayscale transform. If mode is "eval", return None.
    """
    if mode == "eval":
        return None
    if self.backend == "torchvision":
        return T.RandomGrayscale(self.p)
    elif self.backend == "albumentations":
        return A.ToGray(p=self.p)
    else:
        raise ValueError(f"Backend {self.backend} not supported")
VideoColorJitter

Bases: VideoTransform

get_transform
get_transform(mode: Literal['train', 'eval'] = 'train') -> Callable | None

Get the color jitter transform, only used in train mode.

Parameters:

Name Type Description Default
mode Literal['train', 'eval']

The mode to get the transform for.

'train'

Returns:

Type Description
Callable | None

Callable | None: If mode is "train", return a color jitter transform. If mode is "eval", return None.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None:
    """Get the color jitter transform, only used in train mode.

    Args:
        mode (Literal["train", "eval"]): The mode to get the transform for.

    Returns:
        Callable | None: If mode is "train", return a color jitter transform. If mode is "eval", return None.
    """
    if mode == "eval":
        return None
    if self.backend == "torchvision":
        return T.ColorJitter(
            brightness=self.brightness,
            contrast=self.contrast,
            saturation=self.saturation,
            hue=self.hue,
        )
    elif self.backend == "albumentations":
        return A.ColorJitter(
            brightness=self.brightness,
            contrast=self.contrast,
            saturation=self.saturation,
            hue=self.hue,
            p=1,
        )
    else:
        raise ValueError(f"Backend {self.backend} not supported")
VideoRandomGrayscale

Bases: VideoTransform

get_transform
get_transform(mode: Literal['train', 'eval'] = 'train') -> Callable | None

Get the grayscale transform, only used in train mode.

Parameters:

Name Type Description Default
mode Literal['train', 'eval']

The mode to get the transform for.

'train'

Returns:

Type Description
Callable | None

Callable | None: If mode is "train", return a grayscale transform. If mode is "eval", return None.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None:
    """Get the grayscale transform, only used in train mode.

    Args:
        mode (Literal["train", "eval"]): The mode to get the transform for.

    Returns:
        Callable | None: If mode is "train", return a grayscale transform. If mode is "eval", return None.
    """
    if mode == "eval":
        return None
    if self.backend == "torchvision":
        return T.RandomGrayscale(self.p)
    elif self.backend == "albumentations":
        return A.ToGray(p=self.p)
    else:
        raise ValueError(f"Backend {self.backend} not supported")
VideoRandomPosterize

Bases: VideoTransform

get_transform
get_transform(mode: Literal['train', 'eval'] = 'train') -> Callable | None

Get the posterize transform, only used in train mode.

Parameters:

Name Type Description Default
mode Literal['train', 'eval']

The mode to get the transform for.

'train'

Returns:

Type Description
Callable | None

Callable | None: If mode is "train", return a posterize transform. If mode is "eval", return None.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None:
    """Get the posterize transform, only used in train mode.

    Args:
        mode (Literal["train", "eval"]): The mode to get the transform for.

    Returns:
        Callable | None: If mode is "train", return a posterize transform. If mode is "eval", return None.
    """
    if mode == "eval":
        return None
    if self.backend == "torchvision":
        return T.RandomPosterize(bits=self.bits, p=self.p)
    elif self.backend == "albumentations":
        return A.Posterize(num_bits=self.bits, p=self.p)
    else:
        raise ValueError(f"Backend {self.backend} not supported")
VideoToTensor

Bases: VideoTransform

get_transform
get_transform(mode: Literal['train', 'eval'] = 'train') -> Callable

Get the to tensor transform. Same transform for both train and eval.

Parameters:

Name Type Description Default
mode Literal['train', 'eval']

The mode to get the transform for.

'train'

Returns:

Name Type Description
Callable Callable

The to tensor transform.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable:
    """Get the to tensor transform. Same transform for both train and eval.

    Args:
        mode (Literal["train", "eval"]): The mode to get the transform for.

    Returns:
        Callable: The to tensor transform.
    """
    if self.backend != "torchvision":
        raise ValueError(f"Backend {self.backend} not supported")
    # Mode-independent: always the class's uint8 -> float tensor converter.
    return self.__class__.to_tensor
check_input
check_input(data: dict)

Check if the input data has the correct shape. Expected video shape: [T, H, W, C], dtype np.uint8

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
def check_input(self, data: dict):
    """Check if the input data has the correct shape.
    Expected video shape: [T, H, W, C] (optionally batched to 5 dims), dtype np.uint8.
    """
    for key in self.apply_to:
        assert key in data, f"Key {key} not found in data. Available keys: {data.keys()}"
        frames = data[key]
        assert frames.ndim in [
            4,
            5,
        ], f"Video {key} must have 4 or 5 dimensions, got {frames.ndim}"
        assert (
            frames.dtype == np.uint8
        ), f"Video {key} must have dtype uint8, got {frames.dtype}"
        # Shape is (..., H, W, C); resolutions are reported as (W, H).
        input_resolution = frames.shape[-3:-1][::-1]
        # Keys without a configured resolution are accepted as-is.
        expected_resolution = self.original_resolutions.get(key, input_resolution)
        assert (
            input_resolution == expected_resolution
        ), f"Video {key} has invalid resolution {input_resolution}, expected {expected_resolution}. Full shape: {frames.shape}"
to_tensor staticmethod
to_tensor(frames: ndarray) -> torch.Tensor

Convert numpy array to tensor efficiently.

Parameters:

Name Type Description Default
frames ndarray

numpy array of shape [T, H, W, C] in uint8 format

required

Returns: tensor of shape [T, C, H, W] in range [0, 1]

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
@staticmethod
def to_tensor(frames: np.ndarray) -> torch.Tensor:
    """Convert numpy array to tensor efficiently.

    Args:
        frames: numpy array of shape [T, H, W, C] in uint8 format
    Returns:
        tensor of shape [T, C, H, W] in range [0, 1]
    """
    frames_tensor = torch.from_numpy(frames).to(torch.float32) / 255.0
    return frames_tensor.permute(0, 3, 1, 2)  # [T, C, H, W]
VideoToNumpy

Bases: VideoTransform

get_transform
get_transform(mode: Literal['train', 'eval'] = 'train') -> Callable

Get the to numpy transform. Same transform for both train and eval.

Parameters:

Name Type Description Default
mode Literal['train', 'eval']

The mode to get the transform for.

'train'

Returns:

Name Type Description
Callable Callable

The to numpy transform.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable:
    """Get the to numpy transform. Same transform for both train and eval.

    Args:
        mode (Literal["train", "eval"]): The mode to get the transform for.

    Returns:
        Callable: The to numpy transform.
    """
    if self.backend != "torchvision":
        raise ValueError(f"Backend {self.backend} not supported")
    # Mode-independent: always the class's tensor -> uint8 array converter.
    return self.__class__.to_numpy
to_numpy staticmethod
to_numpy(frames: Tensor) -> np.ndarray

Convert tensor back to numpy array efficiently.

Parameters:

Name Type Description Default
frames Tensor

tensor of shape [T, C, H, W] in range [0, 1]

required

Returns: numpy array of shape [T, H, W, C] in uint8 format

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
@staticmethod
def to_numpy(frames: torch.Tensor) -> np.ndarray:
    """Convert tensor back to numpy array efficiently.

    Args:
        frames: tensor of shape [T, C, H, W] in range [0, 1]
    Returns:
        numpy array of shape [T, H, W, C] in uint8 format
    """
    return (frames.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy()
VideoToPIL

Bases: VideoTransform

get_transform
get_transform(mode: Literal['train', 'eval'] = 'train') -> Callable

Get the to PIL transform. Same transform for both train and eval.

Parameters:

Name Type Description Default
mode Literal['train', 'eval']

The mode to get the transform for.

'train'

Returns:

Name Type Description
Callable Callable

The to PIL transform.

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable:
    """Get the to PIL transform. Same transform for both train and eval.

    Args:
        mode (Literal["train", "eval"]): The mode to get the transform for.

    Returns:
        Callable: The to PIL transform.
    """
    if self.backend != "torchvision":
        raise ValueError(f"Backend {self.backend} not supported")
    # Mode-independent: always the class's tensor -> PIL converter.
    return self.__class__.to_pil
to_pil staticmethod
to_pil(frames: Tensor) -> Image.Image

Convert tensor back to PIL Image.

Parameters:

Name Type Description Default
frames Tensor

tensor of shape [T, C, H, W] in range [0, 1]

required

Returns: PIL Image of shape [T, H, W, C] in uint8 format

Source code in AlphaBrain/dataloader/gr00t_lerobot/transform/video.py
@staticmethod
def to_pil(frames: torch.Tensor) -> Image.Image:
    """Convert tensor back to PIL Image.

    Args:
        frames: tensor of shape [T, C, H, W] in range [0, 1]
    Returns:
        PIL Image of shape [T, H, W, C] in uint8 format
    """
    # NOTE(review): Image.fromarray is handed a 4-D [T, H, W, C] array here,
    # but PIL only accepts 2-D/3-D arrays — presumably callers pass a single
    # frame (T squeezed out); confirm, otherwise this raises for T > 1.
    return Image.fromarray((frames.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy())

video

get_frames_by_timestamps
get_frames_by_timestamps(video_path: str, timestamps: list[float] | ndarray, video_backend: str = 'decord', video_backend_kwargs: dict = {}) -> np.ndarray

Get frames from a video at specified timestamps.

Args:
    video_path (str): Path to the video file.
    timestamps (list[float] | np.ndarray): Timestamps to retrieve frames for, in seconds.
    video_backend (str, optional): Video backend to use. Defaults to "decord".

Returns:
    np.ndarray: Frames at the specified timestamps.

Source code in AlphaBrain/dataloader/gr00t_lerobot/video.py
def get_frames_by_timestamps(
    video_path: str,
    timestamps: list[float] | np.ndarray,
    video_backend: str = "decord",
    video_backend_kwargs: dict = {},
) -> np.ndarray:
    """Get frames from a video at specified timestamps.
    Args:
        video_path (str): Path to the video file.
        timestamps (list[int] | np.ndarray): Timestamps to retrieve frames for, in seconds.
        video_backend (str, optional): Video backend to use. Defaults to "decord".
    Returns:
        np.ndarray: Frames at the specified timestamps.
    """
    if video_backend == "decord":
        # For some GPUs, AV format data cannot be read
        if not DECORD_AVAILABLE:
            raise ImportError("decord is not available.")
        vr = decord.VideoReader(video_path, **video_backend_kwargs)
        num_frames = len(vr)
        # Retrieve the timestamps for each frame in the video
        frame_ts: np.ndarray = vr.get_frame_timestamp(range(num_frames))
        # Map each requested timestamp to the closest frame index
        # Only take the first element of the frame_ts array which corresponds to start_seconds
        indices = np.abs(frame_ts[:, :1] - timestamps).argmin(axis=0)
        frames = vr.get_batch(indices)
        return frames.asnumpy()
    elif video_backend == "torchcodec":
        if not TORCHCODEC_AVAILABLE:
            raise ImportError("torchcodec is not available.")
        decoder = torchcodec.decoders.VideoDecoder(
            video_path, device="cpu", dimension_order="NHWC", num_ffmpeg_threads=0
        )
        return decoder.get_frames_played_at(seconds=timestamps).data.numpy()
    elif video_backend == "opencv":
        # Open the video file
        cap = cv2.VideoCapture(video_path, **video_backend_kwargs)
        if not cap.isOpened():
            raise ValueError(f"Unable to open video file: {video_path}")
        # Retrieve the total number of frames
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Calculate timestamps for each frame
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_ts = np.arange(num_frames) / fps
        frame_ts = frame_ts[:, np.newaxis]  # Reshape to (num_frames, 1) for broadcasting
        # Map each requested timestamp to the closest frame index
        indices = np.abs(frame_ts - timestamps).argmin(axis=0)
        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                raise ValueError(f"Unable to read frame at index {idx}")
            frames.append(frame)
        cap.release()
        frames = np.array(frames)
        return frames
    elif video_backend == "pyav":
        # Use PyAV directly without torchvision wrapper
        container = None
        try:
            container = av.open(video_path)
            stream = container.streams.video[0]

            # Get video properties
            time_base = float(stream.time_base)
            fps = float(stream.average_rate) if stream.average_rate else float(stream.guessed_rate)
            duration = float(stream.duration * time_base) if stream.duration else None

            loaded_frames = []

            for target_ts in timestamps:
                # Convert timestamp to pts (presentation timestamp)
                target_pts = int(target_ts / time_base)

                # Seek to the target timestamp (seek to keyframe before target)
                container.seek(target_pts, stream=stream, backward=True, any_frame=False)

                closest_frame = None
                closest_ts_diff = float('inf')

                _iter_count = 0
                _MAX_ITER = 500  # prevent infinite loop on corrupted/slow video
                for frame in container.decode(video=0):
                    _iter_count += 1
                    if _iter_count > _MAX_ITER:
                        break
                    current_ts = float(frame.pts * time_base)
                    current_diff = abs(current_ts - target_ts)

                    if current_diff < closest_ts_diff:
                        # Release the previous frame
                        if closest_frame is not None:
                            del closest_frame
                        closest_ts_diff = current_diff
                        closest_frame = frame

                    # If we've passed the target and diff is increasing, stop
                    if current_ts > target_ts and current_diff > closest_ts_diff:
                        break

                    # Also stop if we're significantly past the target (optimization)
                    if current_ts > target_ts + 1.0:
                        break

                if closest_frame is not None:
                    frame_data = closest_frame.to_ndarray(format="rgb24")
                    loaded_frames.append(frame_data)
                    del closest_frame
                else:
                    raise ValueError(f"Unable to find frame at timestamp {target_ts}")

            frames = np.array(loaded_frames)
            return frames

        finally:
            if container is not None:
                container.close()
                container = None

    elif video_backend == "torchvision_av":
        torchvision.set_video_backend("pyav")
        loaded_frames = []
        loaded_ts = []

        reader = None
        try:
            reader = torchvision.io.VideoReader(video_path, "video")

            for target_ts in timestamps:
                # Reset reader state
                reader.seek(target_ts, keyframes_only=True)

                closest_frame = None
                closest_ts_diff = float('inf')
                _iter_count = 0
                _MAX_ITER = 500  # prevent infinite loop on corrupted/slow video

                for frame in reader:
                    _iter_count += 1
                    if _iter_count > _MAX_ITER:
                        break
                    current_ts = frame["pts"]
                    current_diff = abs(current_ts - target_ts)

                    if closest_frame is None:
                        closest_frame = frame

                    if current_diff < closest_ts_diff:
                        # Release the previous frame
                        if closest_frame is not None:
                            del closest_frame
                        closest_ts_diff = current_diff
                        closest_frame = frame
                    else:
                        # The time difference starts to increase, stop searching
                        break

                if closest_frame is not None:
                    frame_data = closest_frame["data"]
                    if isinstance(frame_data, torch.Tensor):
                        frame_data = frame_data.cpu().numpy()
                    loaded_frames.append(frame_data)
                    loaded_ts.append(closest_frame["pts"])

                    # Immediately release frame reference
                    del closest_frame

        finally:
            # Thoroughly clean resources
            if reader is not None:
                if hasattr(reader, '_c'):
                    reader._c = None
                if hasattr(reader, 'container'):
                    reader.container.close()
                    reader.container = None

        frames = np.array(loaded_frames)
        return frames.transpose(0, 2, 3, 1)
    else:
        raise NotImplementedError
get_all_frames
get_all_frames(video_path: str, video_backend: str = 'decord', video_backend_kwargs: dict = {}, resize_size: tuple[int, int] | None = None) -> np.ndarray

Get all frames from a video. Args: video_path (str): Path to the video file. video_backend (str, optional): Video backend to use. Defaults to "decord". video_backend_kwargs (dict, optional): Keyword arguments for the video backend. resize_size (tuple[int, int], optional): Resize size for the frames. Defaults to None.

Source code in AlphaBrain/dataloader/gr00t_lerobot/video.py
def get_all_frames(
    video_path: str,
    video_backend: str = "decord",
    video_backend_kwargs: dict = {},
    resize_size: tuple[int, int] | None = None,
) -> np.ndarray:
    """Get all frames from a video.
    Args:
        video_path (str): Path to the video file.
        video_backend (str, optional): Video backend to use. Defaults to "decord".
        video_backend_kwargs (dict, optional): Keyword arguments for the video backend.
        resize_size (tuple[int, int], optional): Resize size for the frames. Defaults to None.
    """
    if video_backend == "decord":
        if not DECORD_AVAILABLE:
            raise ImportError("decord is not available.")
        vr = decord.VideoReader(video_path, **video_backend_kwargs)
        frames = vr.get_batch(range(len(vr))).asnumpy()
    elif video_backend == "torchcodec":
        if not TORCHCODEC_AVAILABLE:
            raise ImportError("torchcodec is not available.")
        decoder = torchcodec.decoders.VideoDecoder(
            video_path, device="cpu", dimension_order="NHWC", num_ffmpeg_threads=0
        )
        frames = decoder.get_frames_at(indices=range(len(decoder)))
        return frames.data.numpy(), frames.pts_seconds.numpy()
    elif video_backend == "pyav":
        container = av.open(video_path)
        frames = []
        for frame in container.decode(video=0):
            frame = frame.to_ndarray(format="rgb24")
            frames.append(frame)
        frames = np.array(frames)
    elif video_backend == "torchvision_av":
        # set backend and reader
        torchvision.set_video_backend("pyav")
        reader = torchvision.io.VideoReader(video_path, "video")
        frames = []
        for frame in reader:
            frames.append(frame["data"].numpy())
        frames = np.array(frames)
        frames = frames.transpose(0, 2, 3, 1)
    else:
        raise NotImplementedError(f"Video backend {video_backend} not implemented")
    # resize frames if specified
    if resize_size is not None:
        frames = [cv2.resize(frame, resize_size) for frame in frames]
        frames = np.array(frames)
    return frames

Qwen-VL LLaVA-JSON subpackage

qwenvl_llavajson

rope2d

get_rope_index_25
get_rope_index_25(spatial_merge_size: Optional[int] = 2, input_ids: Optional[LongTensor] = None, image_grid_thw: Optional[LongTensor] = None, video_grid_thw: Optional[LongTensor] = None, second_per_grid_ts: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]

Calculate the 3D rope index based on image and video's temporal, height and width in LLM.

Explanation

Each embedding sequence contains vision embedding and text embedding or just contains text embedding.

For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. Examples: input_ids: [T T T T T], here T is for text. temporal position_ids: [0, 1, 2, 3, 4] height position_ids: [0, 1, 2, 3, 4] width position_ids: [0, 1, 2, 3, 4]

For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part and 1D rotary position embedding for text part. Examples: Temporal (Time): 3 patches, representing different segments of the video in time. Height: 2 patches, dividing each frame vertically. Width: 2 patches, dividing each frame horizontally. We also have some important parameters: fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] text temporal position_ids: [101, 102, 103, 104, 105] text height position_ids: [101, 102, 103, 104, 105] text width position_ids: [101, 102, 103, 104, 105] Here we calculate the text start position_ids as the max vision position_ids plus 1.

Parameters:

Name Type Description Default
input_ids `torch.LongTensor` of shape `(batch_size, sequence_length)`

Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.

None
image_grid_thw `torch.LongTensor` of shape `(num_images, 3)`, *optional*

The temporal, height and width of feature shape of each image in LLM.

None
video_grid_thw `torch.LongTensor` of shape `(num_videos, 3)`, *optional*

The temporal, height and width of feature shape of each video in LLM.

None
second_per_grid_ts `torch.Tensor` of shape `(num_videos)`, *optional*

The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.

None
attention_mask `torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*

Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:

  • 1 for tokens that are not masked,
  • 0 for tokens that are masked.
None

Returns:

Type Description
Tensor

position_ids (torch.LongTensor of shape (3, batch_size, sequence_length))

Tensor

mrope_position_deltas (torch.Tensor of shape (batch_size))

Source code in AlphaBrain/dataloader/qwenvl_llavajson/rope2d.py
def get_rope_index_25(
    spatial_merge_size: Optional[int] = 2,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate the 3D rope index based on image and video's temporal, height and width in LLM.

    Explanation:
        Each embedding sequence contains vision embedding and text embedding or just contains text embedding.

        For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
        Examples:
            input_ids: [T T T T T], here T is for text.
            temporal position_ids: [0, 1, 2, 3, 4]
            height position_ids: [0, 1, 2, 3, 4]
            width position_ids: [0, 1, 2, 3, 4]

        For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
        and 1D rotary position embedding for text part.
        Examples:
            Temporal (Time): 3 patches, representing different segments of the video in time.
            Height: 2 patches, dividing each frame vertically.
            Width: 2 patches, dividing each frame horizontally.
            We also have some important parameters:
            fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
            tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
            temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
            interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
            input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
            vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
            vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
            vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
            text temporal position_ids: [101, 102, 103, 104, 105]
            text height position_ids: [101, 102, 103, 104, 105]
            text width position_ids: [101, 102, 103, 104, 105]
            Here we calculate the text start position_ids as the max vision position_ids plus 1.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
            The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

    Returns:
        position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
        mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
    """
    # Hard-coded special-token ids — presumably the Qwen2.5-VL vocabulary
    # positions for <image>, <video> and <vision_start>; confirm against the
    # tokenizer config if the model changes.
    image_token_id = 151655
    video_token_id = 151656
    vision_start_token_id = 151652
    mrope_position_deltas = []
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        # Multimodal path: build 3D (t, h, w) positions for vision spans and
        # 1D (replicated) positions for interleaved text spans.
        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)
        # Filled per-sequence below; initialized to ones so padded slots keep
        # a harmless value.
        position_ids = torch.ones(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        image_index, video_index = 0, 0
        attention_mask = attention_mask.to(total_input_ids.device)
        # NOTE: the loop variable deliberately shadows the outer `input_ids`;
        # later reads of `input_ids` refer to the last (unpadded) sequence.
        for i, input_ids in enumerate(total_input_ids):
            # Drop padding positions for this sequence.
            input_ids = input_ids[attention_mask[i] == 1]
            image_nums, video_nums = 0, 0
            # Each vision span is announced by a vision-start token; the token
            # right after it tells whether the span is an image or a video.
            vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
            vision_tokens = input_ids[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
            input_tokens = input_ids.tolist()
            llm_pos_ids_list: list = []
            st = 0
            remain_images, remain_videos = image_nums, video_nums
            # Walk through the sequence, consuming one vision span per iteration.
            for _ in range(image_nums + video_nums):
                # `len(input_tokens) + 1` is a sentinel meaning "no such token
                # ahead", so the comparison below picks the other modality.
                if image_token_id in input_tokens and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    ed_image = len(input_tokens) + 1
                if video_token_id in input_tokens and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1
                if ed_image < ed_video:
                    # Next span is an image: single temporal step, no per-grid time.
                    t, h, w = (
                        image_grid_thw[image_index][0],
                        image_grid_thw[image_index][1],
                        image_grid_thw[image_index][2],
                    )
                    second_per_grid_t = 0
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image

                else:
                    # Next span is a video: temporal ids are scaled by the
                    # per-grid duration (default 1.0 when not provided).
                    t, h, w = (
                        video_grid_thw[video_index][0],
                        video_grid_thw[video_index][1],
                        video_grid_thw[video_index][2],
                    )
                    if second_per_grid_ts is not None:
                        second_per_grid_t = second_per_grid_ts[video_index]
                    else:
                        second_per_grid_t = 1.0
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video
                # Spatial dims are divided by the merge size: one LLM token per
                # spatial_merge_size x spatial_merge_size patch block.
                llm_grid_t, llm_grid_h, llm_grid_w = (
                    t.item(),
                    h.item() // spatial_merge_size,
                    w.item() // spatial_merge_size,
                )
                text_len = ed - st

                # Text before this vision span: plain 1D positions replicated
                # across the three rope axes, continuing from the running max.
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)

                # NOTE(review): the factor 2 presumably encodes tokens_per_second
                # for this model — confirm against the processor config.
                time_tensor = expanded_range * second_per_grid_t * 2

                time_tensor_long = time_tensor.long()
                t_index = time_tensor_long.flatten()

                # Row-major (t, h, w) enumeration of the vision grid.
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                # Advance past the vision tokens just consumed.
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w

            # Trailing text after the last vision span.
            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            # Scatter the computed positions back into the non-padded slots.
            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
            # Delta between the largest used position + 1 and the full padded length.
            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
        return position_ids, mrope_position_deltas
    else:
        # Text-only path: 1D positions replicated across the three rope axes.
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
        else:
            position_ids = (
                torch.arange(input_ids.shape[1], device=input_ids.device)
                .view(1, 1, -1)
                .expand(3, input_ids.shape[0], -1)
            )
            mrope_position_deltas = torch.zeros(
                [input_ids.shape[0], 1],
                device=input_ids.device,
                dtype=input_ids.dtype,
            )

        return position_ids, mrope_position_deltas
get_rope_index_2
get_rope_index_2(spatial_merge_size: Optional[int] = 2, input_ids: Optional[LongTensor] = None, image_grid_thw: Optional[LongTensor] = None, video_grid_thw: Optional[LongTensor] = None, second_per_grid_ts: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]

Calculate the 3D rope index based on image and video's temporal, height and width in LLM.

Explanation

Each embedding sequence contains vision embedding and text embedding or just contains text embedding.

For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. Examples: input_ids: [T T T T T], here T is for text. temporal position_ids: [0, 1, 2, 3, 4] height position_ids: [0, 1, 2, 3, 4] width position_ids: [0, 1, 2, 3, 4]

For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part and 1D rotary position embedding for text part. Examples: Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches. input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] text temporal position_ids: [3, 4, 5, 6, 7] text height position_ids: [3, 4, 5, 6, 7] text width position_ids: [3, 4, 5, 6, 7] Here we calculate the text start position_ids as the max vision position_ids plus 1.

Parameters:

Name Type Description Default
input_ids `torch.LongTensor` of shape `(batch_size, sequence_length)`

Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.

None
image_grid_thw `torch.LongTensor` of shape `(num_images, 3)`, *optional*

The temporal, height and width of feature shape of each image in LLM.

None
video_grid_thw `torch.LongTensor` of shape `(num_videos, 3)`, *optional*

The temporal, height and width of feature shape of each video in LLM.

None
attention_mask `torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*

Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:

  • 1 for tokens that are not masked,
  • 0 for tokens that are masked.
None

Returns:

Type Description
Tensor

position_ids (torch.LongTensor of shape (3, batch_size, sequence_length))

Tensor

mrope_position_deltas (torch.Tensor of shape (batch_size))

Source code in AlphaBrain/dataloader/qwenvl_llavajson/rope2d.py
def get_rope_index_2(
    spatial_merge_size: Optional[int] = 2,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate the 3D rope index based on image and video's temporal, height and width in LLM.

    Explanation:
        Each embedding sequence contains vision embedding and text embedding or just contains text embedding.

        For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
        Examples:
            input_ids: [T T T T T], here T is for text.
            temporal position_ids: [0, 1, 2, 3, 4]
            height position_ids: [0, 1, 2, 3, 4]
            width position_ids: [0, 1, 2, 3, 4]

        For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
        and 1D rotary position embedding for text part.
        Examples:
            Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
            input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
            vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
            vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
            vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
            text temporal position_ids: [3, 4, 5, 6, 7]
            text height position_ids: [3, 4, 5, 6, 7]
            text width position_ids: [3, 4, 5, 6, 7]
            Here we calculate the text start position_ids as the max vision position_ids plus 1.

    Args:
        spatial_merge_size (`int`, *optional*, defaults to 2):
            Side length of the square patch-merging window in the vision encoder; each
            image/video grid's height and width are divided by this to get the number of
            LLM-level vision tokens.
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        second_per_grid_ts (`torch.Tensor`, *optional*):
            NOTE(review): accepted but never used in this implementation. In Qwen2.5-VL's
            upstream `get_rope_index` it scales the temporal index by real video time;
            here the temporal index is a plain `arange` (Qwen2-VL behavior). Confirm
            whether video-time scaling was intentionally dropped.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

    Returns:
        position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
        mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
    """
    # Hard-coded Qwen2-VL special token ids: <|image_pad|>, <|video_pad|>, <|vision_start|>.
    image_token_id = 151655
    video_token_id = 151656
    vision_start_token_id = 151652
    mrope_position_deltas = []
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)
        # Initialize to ones: padded positions are never overwritten below, so they
        # end up with position id 1 (mirrors `masked_fill_(..., 1)` in the text-only branch).
        position_ids = torch.ones(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        # Running cursors into image_grid_thw / video_grid_thw, shared across the batch
        # because the grids are flattened over all sequences.
        image_index, video_index = 0, 0
        for i, input_ids in enumerate(total_input_ids):  # NOTE: deliberately shadows the outer `input_ids`
            # Drop padding so positions are computed over real tokens only.
            input_ids = input_ids[attention_mask[i] == 1]
            image_nums, video_nums = 0, 0
            # Each vision segment is introduced by <|vision_start|>; the token right after
            # it tells whether the segment is an image or a video.
            vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
            vision_tokens = input_ids[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
            input_tokens = input_ids.tolist()
            llm_pos_ids_list: list = []  # chunks of (3, len) position ids, concatenated at the end
            st = 0  # scan cursor into input_tokens
            remain_images, remain_videos = image_nums, video_nums
            # Walk the sequence segment by segment, alternating text spans and vision spans.
            for _ in range(image_nums + video_nums):
                if image_token_id in input_tokens and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    # Sentinel past the end so the min/compare below picks the other modality.
                    ed_image = len(input_tokens) + 1
                if video_token_id in input_tokens and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1
                # Whichever vision token occurs first determines which grid to consume next.
                if ed_image < ed_video:
                    t, h, w = (
                        image_grid_thw[image_index][0],
                        image_grid_thw[image_index][1],
                        image_grid_thw[image_index][2],
                    )
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image
                else:
                    t, h, w = (
                        video_grid_thw[video_index][0],
                        video_grid_thw[video_index][1],
                        video_grid_thw[video_index][2],
                    )
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video
                # Spatial merging shrinks H and W; temporal dimension is kept as-is.
                llm_grid_t, llm_grid_h, llm_grid_w = (
                    t.item(),
                    h.item() // spatial_merge_size,
                    w.item() // spatial_merge_size,
                )
                text_len = ed - st  # text tokens preceding this vision block

                # Text span: 1D positions, identical on all three (t/h/w) axes,
                # continuing from the max position assigned so far.
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                # Vision span: 3D positions over the (t, h, w) grid, flattened in
                # t-major, then h, then w order; offset so it starts right after the text span.
                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                # Advance the cursor past this vision block's tokens.
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w

            # Trailing text after the last vision block, if any.
            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            # Scatter computed positions back into the non-padded slots of this row.
            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
            # Delta = (next free position) - (full sequence length); used downstream to
            # offset generated-token positions during decoding.
            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
        return position_ids, mrope_position_deltas
    else:
        # Text-only (no vision grids): plain 1D positions replicated across the 3 rope axes.
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Pad positions get a fixed value of 1 (matches the ones-init in the vision branch).
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
        else:
            position_ids = (
                torch.arange(input_ids.shape[1], device=input_ids.device)
                .view(1, 1, -1)
                .expand(3, input_ids.shape[0], -1)
            )
            mrope_position_deltas = torch.zeros(
                [input_ids.shape[0], 1],
                device=input_ids.device,
                dtype=input_ids.dtype,
            )

        return position_ids, mrope_position_deltas