Training › Reinforcement Learning¶
Source path: AlphaBrain/training/reinforcement_learning/
Full implementation of VLA online RL training (RLActionToken — RL Token). Paper: RL Token: Bootstrapping Online RL with VLA Models (Physical Intelligence).
Layout:
- algos/RLActionToken/ — encoder/decoder, actor-critic, trainer, fast rollout
- common/ — rollout, replay buffer, checkpoint I/O
- envs/ — LIBERO environment, persistent env pool, env workers
- eval/ — LIBERO evaluation and shard aggregation
- trainers/ — on-policy / off-policy / pretrain entrypoints
Top-level re-exports¶
reinforcement_learning ¶
RLActionToken algorithm¶
Encoder / Decoder¶
action_token_encoder_decoder ¶
ActionToken Encoder-Decoder: Information bottleneck between frozen VLA and small RL network.
Inspired by the "RL Token" paper (Physical Intelligence, 2026), but with deviations from the paper's construction:
- Encoder input is the VLA's action-query hidden states (M × H) gathered at the action-token positions, not the full image-token sequence (N × H) as in the paper's Fig. 2.
- An extra Linear(H → bottleneck_dim) projection compresses the per-token dimension (e.g. 2048 → 256); the paper keeps the RL token at the VLA hidden dim.
- The decoder is a self-attention transformer with a causal mask and a prefix token, not the encoder-decoder cross-attention structure of the paper's Eq. 2.
A faithful paper-accurate reimplementation is still under test.
Paper Eq. 1 — Encoder: z_rl = g_φ([z_{1:M}, e_rl])_{M+1}. Append the learnable embedding e_rl to the VLA token sequence, run it through a self-attention encoder transformer, and take the output at the e_rl position.
Paper Eq. 2 — Decoder (autoregressive reconstruction): L_ro = E[ Σ_{i=1}^{M} ‖h_φ(d_φ([z_rl, sg(z_{1:i-1})]))_i − sg(z_i)‖² ]. Reconstruct the VLA tokens autoregressively from z_rl to enforce information preservation through the bottleneck.
ActionTokenEncoder ¶
ActionTokenEncoder(input_dim: int = 2048, bottleneck_dim: int = 256, num_heads: int = 4, num_layers: int = 2, dropout: float = 0.0)
Bases: Module
Paper Eq. 1: Compress VLA action_queries (B, M, H) → rl_token (B, 1, D).
Appends a learnable e_rl to the token sequence and processes with self-attention (TransformerEncoderLayer). The output at the e_rl position is projected to the bottleneck dimension.
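A minimal sketch of this Eq. 1 construction (class name and toy dimensions are illustrative, not the module's real API):

```python
import torch
import torch.nn as nn

class MiniActionTokenEncoder(nn.Module):
    """Sketch of paper Eq. 1: append a learnable e_rl token to the VLA
    action-query sequence, run self-attention, take the output at the
    e_rl position, then project to the bottleneck dim (a deviation
    from the paper, which keeps the VLA hidden dim)."""
    def __init__(self, input_dim=64, bottleneck_dim=16, num_heads=4, num_layers=1):
        super().__init__()
        self.e_rl = nn.Parameter(torch.zeros(1, 1, input_dim))
        layer = nn.TransformerEncoderLayer(
            d_model=input_dim, nhead=num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.proj = nn.Linear(input_dim, bottleneck_dim)

    def forward(self, action_queries):            # (B, M, H)
        b = action_queries.shape[0]
        seq = torch.cat([action_queries, self.e_rl.expand(b, -1, -1)], dim=1)
        out = self.encoder(seq)                   # (B, M+1, H)
        return self.proj(out[:, -1:, :])          # (B, 1, D) — e_rl position

enc = MiniActionTokenEncoder()
z_rl = enc(torch.randn(2, 8, 64))
print(z_rl.shape)  # torch.Size([2, 1, 16])
```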
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_encoder_decoder.py
forward ¶
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
action_queries | Tensor | (B, M, H) from frozen VLA | required |
Returns: rl_token: (B, 1, D_bottleneck)
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_encoder_decoder.py
ActionTokenDecoder ¶
ActionTokenDecoder(bottleneck_dim: int = 256, output_dim: int = 2048, chunk_len: int = 8, num_heads: int = 4, num_layers: int = 2, dropout: float = 0.0)
Bases: Module
Paper Eq. 2: Autoregressive reconstruction of VLA tokens from z_rl.
L_ro = E[ Σ_i ‖h_φ(d_φ([z_rl, sg(z_{1:i-1})]))_i − sg(z_i)‖² ]
The decoder takes z_rl as prefix, and autoregressively reconstructs each VLA token conditioned on z_rl and previously reconstructed tokens. A causal mask ensures position i can only attend to positions < i (plus the z_rl prefix which is always visible).
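The masking described above can be sketched as a plain additive attention mask: because the z_rl prefix sits at index 0 of the decoder sequence, the standard causal mask already keeps it visible to every position while hiding future tokens:

```python
import torch

def causal_mask_with_prefix(seq_len: int) -> torch.Tensor:
    """Additive mask for a decoder sequence [z_rl, z_1, ..., z_{M-1}]:
    position i may attend to positions <= i, so the z_rl prefix at
    index 0 is visible everywhere and future tokens are masked."""
    mask = torch.full((seq_len, seq_len), float("-inf"))
    return torch.triu(mask, diagonal=1)  # 0 on/below diagonal, -inf above

m = causal_mask_with_prefix(4)
print(m)
```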
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_encoder_decoder.py
forward ¶
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
rl_token | Tensor | (B, 1, D_bottleneck) | required |
target_tokens | Tensor | (B, M, H) stop-gradient VLA tokens for teacher forcing. If None, uses learned positional embeddings (inference mode). | None |
Returns: reconstructed: (B, M, H)
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_encoder_decoder.py
ActionTokenEncoderDecoder ¶
ActionTokenEncoderDecoder(input_dim: int = 2048, bottleneck_dim: int = 256, chunk_len: int = 8, num_heads: int = 4, encoder_layers: int = 2, decoder_layers: int = 2, dropout: float = 0.0)
Bases: Module
Combined Encoder-Decoder for ActionToken pretraining.
Training: autoregressive reconstruction with teacher forcing. Inference: encoder only (decoder not used during RL).
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_encoder_decoder.py
forward ¶
Full encode-decode pass with autoregressive reconstruction loss.
Paper Eq. 2: L_ro = E[ Σ_i ‖reconstructed_i − sg(z_i)‖² ]
Returns:
| Name | Type | Description |
|---|---|---|
rl_token | (B, 1, D) | |
recon_loss | scalar MSE reconstruction loss |
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_encoder_decoder.py
Actor / Critic¶
action_token_actor_critic ¶
ActionToken Actor-Critic, following the RL Token paper (Physical Intelligence) closely on the actor/critic side (the deviations from the paper live in action_token_encoder_decoder.py and action_token_trainer.py).
Key design choices from the paper
- Actor (Eq. 4): π_θ(a | x, ã) = N(μ_θ(x, ã), σ²I). The actor DIRECTLY outputs the action chunk, conditioned on (rl_token, vla_ref):
  - The VLA reference is an INPUT to the network, NOT a structural residual.
  - BC regularization β‖a − ã‖² in the LOSS keeps actions close to the VLA.
  - Reference-action dropout (50%) prevents identity collapse.
- Critic (Eq. 3): Q(s, a) — twin Q-networks (TD3-style).
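A minimal sketch of the Eq. 4 actor design (class name and sizes are illustrative): the VLA reference is a plain conditioning input that is randomly zeroed during training, and the network emits the mean of the full action chunk under a fixed-std Gaussian.

```python
import torch
import torch.nn as nn

class MiniActor(nn.Module):
    """Sketch of Eq. 4: mu = MLP([rl_token, vla_ref]); a ~ N(mu, sigma^2 I).
    The reference is dropped with probability ref_dropout during training
    to prevent identity collapse (copying the VLA)."""
    def __init__(self, d=16, action_dim=7, chunk_len=8, hidden=64,
                 ref_dropout=0.5, fixed_std=0.1):
        super().__init__()
        self.chunk = chunk_len * action_dim
        self.ref_dropout, self.std = ref_dropout, fixed_std
        self.net = nn.Sequential(
            nn.Linear(d + self.chunk, hidden), nn.ReLU(),
            nn.Linear(hidden, self.chunk))

    def forward(self, rl_token, vla_action, deterministic=False):
        ref = vla_action.flatten(1)                    # (B, C*A)
        if self.training and torch.rand(()) < self.ref_dropout:
            ref = torch.zeros_like(ref)                # drop the reference
        mu = self.net(torch.cat([rl_token.flatten(1), ref], dim=-1))
        a = mu if deterministic else mu + self.std * torch.randn_like(mu)
        return a.view(-1, vla_action.shape[1], vla_action.shape[2])

actor = MiniActor().eval()
a = actor(torch.randn(2, 1, 16), torch.randn(2, 8, 7), deterministic=True)
print(a.shape)  # torch.Size([2, 8, 7])
```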
ActionTokenActor ¶
ActionTokenActor(bottleneck_dim: int = 256, action_dim: int = 7, chunk_len: int = 8, hidden_dim: int = 256, ref_dropout: float = 0.5, fixed_std: float = 0.1, prop_dim: int = 0)
Bases: Module
ActionToken actor from paper (Eq. 4-5).
π_θ(a_{1:C} | x, ã_{1:C}) = N(μ_θ(x, ã_{1:C}), σ²I)
The network takes (rl_token, vla_reference_action) as input and DIRECTLY outputs the full action chunk. The VLA reference is just a conditioning signal — the BC regularization in the loss (not in the architecture) keeps the output close to VLA.
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
forward ¶
forward(rl_token: Tensor, vla_action: Tensor, prop_state: Tensor = None, deterministic: bool = False)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
rl_token | Tensor | (B, 1, D) or (B, D) | required |
vla_action | Tensor | (B, chunk_len, action_dim) — VLA reference | required |
prop_state | Tensor | (B, prop_dim) — proprioceptive state (eef_pos+axisangle+gripper) | None |
Returns: action: (B, chunk_len, action_dim) log_prob: (B,) or None if deterministic
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
log_prob_of ¶
log_prob_of(rl_token: Tensor, vla_action: Tensor, taken_action: Tensor, prop_state: Tensor = None) -> torch.Tensor
Compute log_prob of a previously taken action under current policy.
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
ActionTokenQCritic ¶
ActionTokenQCritic(bottleneck_dim: int = 256, action_dim: int = 7, chunk_len: int = 8, hidden_dim: int = 256, prop_dim: int = 0)
Bases: Module
Twin Q-critic from RL Token paper (Eq. 3, following TD3).
Q_ψ(x, a_{1:C}) takes the RL token state AND the action chunk as input. Contains two independent Q-networks; use min(Q1, Q2) for target values.
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
forward ¶
Returns: q1: (B,), q2: (B,)
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
q1_forward ¶
Single Q1 forward (used for actor loss to save compute).
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
ActionTokenCritic ¶
Bases: Module
State value estimator V(s) from rl_token. (Legacy, for PPO path only.)
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
soft_update_target ¶
Polyak averaging: target = (1 - tau) * target + tau * source.
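A minimal sketch of this update, assuming two modules with identical parameter layouts:

```python
import torch

@torch.no_grad()
def soft_update(target, source, tau: float = 0.005):
    """Polyak averaging: target <- (1 - tau) * target + tau * source."""
    for tp, sp in zip(target.parameters(), source.parameters()):
        tp.mul_(1.0 - tau).add_(sp, alpha=tau)

net, tgt = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
soft_update(tgt, net, tau=1.0)  # tau=1 copies source into target
print(torch.allclose(tgt.weight, net.weight))  # True
```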
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
Trainer (loss / update)¶
action_token_trainer ¶
ActionToken Trainer: Two-phase training for the RLActionToken variant.
Phase 1 — Encoder Pretraining: Freeze VLA, train encoder-decoder via reconstruction loss on rollout data.
Phase 2 — Actor-Critic RL: Freeze VLA, use pretrained encoder. Rollout with actor (action editing), update actor + critic via PPO-clip. Optional encoder fine-tuning with reconstruction regularization.
BatchInferenceServer ¶
BatchInferenceServer(frozen_vla, encoder, actor, critic, device: str, max_batch_size: int = 64, batch_timeout_s: float = 0.005, actor_chunk_len: int = None)
Batched VLA + encoder + actor + critic inference for rollout.
Problem with model_lock (old approach):
- Each env thread acquires the lock → VLA forward with batch=1 → releases the lock
- The GPU processes one observation at a time, regardless of how many envs are running
- With chunk_len=8, the GPU is idle 7/8 of the time (CPU env.step dominates)
Solution
- All env threads submit (img_pair, instruction) to a request queue and block
- A single background thread drains the queue and runs ONE batched GPU forward
- All waiting threads get results simultaneously → GPU utilization scales with num_envs
Cross-task batching
- In multi-task mode, multiple tasks on the same GPU share ONE server
- When task-0 envs and task-5 envs both need VLA inference, they batch together
- Effective batch size = n_tasks_per_gpu × n_envs_per_task
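The queue-and-drain pattern behind the server can be sketched dependency-free (`MiniBatchServer` and `batch_fn` are illustrative stand-ins for the real server and the VLA forward):

```python
import queue
import threading

class MiniBatchServer:
    """Sketch of the BatchInferenceServer pattern: worker threads submit
    requests and block; one background thread drains the queue and runs
    a single batched forward (here a toy batch_fn) for all waiters."""
    def __init__(self, batch_fn, max_batch=64, timeout_s=0.005):
        self.batch_fn, self.max_batch, self.timeout = batch_fn, max_batch, timeout_s
        self.q = queue.Queue()
        threading.Thread(target=self._loop, daemon=True).start()

    def infer(self, x):
        ev, out = threading.Event(), {}
        self.q.put((x, ev, out))
        ev.wait()                      # block until the batch result is ready
        return out["y"]

    def _loop(self):
        while True:
            batch = [self.q.get()]     # block for the first request
            try:                       # then drain briefly to fill the batch
                while len(batch) < self.max_batch:
                    batch.append(self.q.get(timeout=self.timeout))
            except queue.Empty:
                pass
            ys = self.batch_fn([b[0] for b in batch])   # ONE batched forward
            for (_, ev, out), y in zip(batch, ys):
                out["y"] = y
                ev.set()

server = MiniBatchServer(lambda xs: [x * 2 for x in xs])
print(server.infer(21))  # 42
```

In the real server the payload is (img_pair, instruction, prop_state) and the batched forward is VLA → encoder → actor → critic, but the synchronization structure is the same.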
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
infer ¶
Called from env threads. Blocks until the batch result is ready.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
img_pair | list | [primary_img, wrist_img] (numpy arrays, already flipped) | required |
instruction | str | task language description | required |
prop_state | ndarray or Tensor | (prop_dim,) — proprioceptive state | None
Returns:
| Type | Description |
|---|---|
| (rl_token_cpu, vla_action_cpu, action_cpu, log_prob_float, value_float) | |
| tensors are (1, ...) on CPU. |
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
ActionTokenStepRecord dataclass ¶
ActionTokenStepRecord(rl_token: Tensor, vla_action: Tensor, action_taken: Tensor, old_log_prob: float, value: float = 0.0, prop_state: Optional[Tensor] = None, sub_tokens: list = list(), images: Optional[list] = None, instruction: Optional[str] = None)
One inference step during RLActionToken rollout.
pretrain_encoder_step ¶
pretrain_encoder_step(frozen_vla, enc_dec: ActionTokenEncoderDecoder, batch_images: list, instructions: list, device: str = 'cuda')
One pretraining step: VLA forward → encoder → decoder → reconstruction loss.
Returns:
| Name | Type | Description |
|---|---|---|
recon_loss | scalar tensor (with grad on enc_dec params only) |
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
collect_observations_fast ¶
collect_observations_fast(suite_name: str, task_id: int, n_observations: int, steps_per_env: int = 20, num_envs: int = 4, n_initial_states: int = 50, libero_python: str = None, seed: int = 42) -> list
Lightweight observation collection for encoder pretraining.
Instead of running full episodes (300+ steps each, 99% CPU idle), just reset envs to diverse initial states and take a few random steps. Returns list of (images, instruction) tuples ready for VLA forward.
Much faster than collect_group: no VLA forward during collection, no action caching, just random exploration for observation diversity.
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
extract_action_queries_from_obs ¶
extract_action_queries_from_obs(frozen_vla, observations: list, batch_size: int = 16, device: str = 'cuda') -> torch.Tensor
Batch-extract action_queries from observations via frozen VLA.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
observations | list | list of (images, instruction) where images=[primary, wrist] | required |
Returns:
| Name | Type | Description |
|---|---|---|
all_queries | Tensor | (N, chunk_len, H) tensor on device |
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
extract_action_queries_dataset ¶
extract_action_queries_dataset(frozen_vla, episodes: list, batch_size: int = 16, device: str = 'cuda') -> torch.Tensor
Batch-extract all action_queries from rollout episodes via frozen VLA. (Legacy: used when full episodes are already collected.)
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
action_token_collect_group ¶
action_token_collect_group(frozen_vla, encoder: ActionTokenEncoderDecoder, actor: ActionTokenActor, critic: ActionTokenCritic, suite_name: str, task_id: int, n_initial_states: int, action_norm_stats: dict, max_steps: int, chunk_len: int, G: int = 8, libero_python: Optional[str] = None, seed: int = 42, num_steps_wait: int = 10, device: str = 'cuda', video_dir: Optional[str] = None, num_envs: int = 4, group_idx: int = 0, batch_server: Optional[BatchInferenceServer] = None, store_images: bool = False, group_size: int = 1, reward_coef: float = 1.0) -> List[ActionTokenEpisode]
Collect G episodes using RLActionToken policy.
Uses BatchInferenceServer for GPU inference: all num_envs env threads submit requests concurrently; a single background thread batches them into one GPU forward pass. This maximizes GPU utilization vs. the old model_lock (batch=1).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch_server | Optional[BatchInferenceServer] | shared server for this GPU (created by caller for cross-task batching). If None, creates a local server for this call only. | None |
group_size | int | number of trajectories per initial state. G episodes are split into G//group_size unique states, each repeated group_size times. Default 1 = legacy behavior (no repeat). | 1 |
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
compute_action_token_gae ¶
compute_action_token_gae(episode: ActionTokenEpisode, gamma: float = 0.99, gae_lambda: float = 0.95)
Compute GAE advantages and returns for a single RLActionToken episode.
Returns:
| Name | Type | Description |
|---|---|---|
advantages | list of floats (len = finish_step) | |
returns | list of floats (len = finish_step) |
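The standard GAE recursion consistent with the description above can be sketched in plain Python on per-episode lists (`last_value` is the bootstrap value, 0.0 for terminated episodes; names are illustrative):

```python
def compute_gae(rewards, values, gamma=0.99, lam=0.95, last_value=0.0):
    """GAE sketch: delta_t = r_t + gamma*V(s_{t+1}) - V(s_t),
    A_t = delta_t + gamma*lam*A_{t+1}; returns_t = A_t + V(s_t)."""
    advantages, gae = [0.0] * len(rewards), 0.0
    for t in reversed(range(len(rewards))):
        next_v = values[t + 1] if t + 1 < len(values) else last_value
        delta = rewards[t] + gamma * next_v - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns

# Sparse terminal reward, as in LIBERO success/failure episodes:
adv, ret = compute_gae([0.0, 0.0, 1.0], [0.1, 0.2, 0.3])
print(len(adv), len(ret))  # 3 3
```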
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
action_token_ppo_loss ¶
action_token_ppo_loss(encoder: ActionTokenEncoderDecoder, actor: ActionTokenActor, critic: ActionTokenCritic, episodes: List[ActionTokenEpisode], gamma: float = 0.99, gae_lambda: float = 0.95, clip_eps: float = 0.2, vf_coef: float = 0.5, recon_loss_coef: float = 0.0, frozen_vla=None, device: str = 'cuda')
Compute PPO loss on a batch of RLActionToken episodes.
Only encoder + actor + critic have gradients. Optionally add reconstruction loss as regularizer.
Returns:
| Name | Type | Description |
|---|---|---|
loss | scalar tensor | |
stats | dict with training metrics |
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
push_episodes_to_buffer ¶
push_episodes_to_buffer(episodes: List[ActionTokenEpisode], replay_buffer: ReplayBuffer, gamma_per_step: float = 0.99)
Convert episode step_records into (s, a, r, s', done) transitions and push them into the replay buffer.
Chunk subsampling (paper, stride=2): for each executed chunk, we create transitions starting from positions [0, 2, 4, 6] within the chunk. A position-p transition:
- state: rl_token at obs[p] within this chunk
- action: action_taken[p:C] || next_chunk.action_taken[0:p] (length-C cross-chunk slice)
- next_state: rl_token at obs[p] within the NEXT chunk
- done: True only for the terminal chunk
Reward (paper Eq. 3): Q̂ = Σ_{t'=0}^{C-1} γ^{t'} r_{t+t'} + γ^C Q_next. With sparse reward, only the terminal chunk has non-zero reward. At stride position p in the terminal chunk (done at step d): reward = γ^(d−p) · episode_reward (discounted by the number of steps from p to done).
Reward scheme: success=1.0, failure=0.0 (paper scheme)
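A worked example of the stride-position reward (the function name is illustrative): with γ = 0.99, chunk_len = 8, and the done step at d = 5, positions 0, 2, 4 receive γ⁵, γ³, γ¹ times the episode reward, while position 6 lies past the done step and gets no transition.

```python
def terminal_chunk_rewards(gamma, d, episode_reward=1.0, stride=2, chunk_len=8):
    """Discounted reward for each stride position p in the terminal chunk:
    reward(p) = gamma**(d - p) * episode_reward, for positions p <= d."""
    return {p: gamma ** (d - p) * episode_reward
            for p in range(0, chunk_len, stride) if p <= d}

print(terminal_chunk_rewards(0.99, d=5))
```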
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
vla_finetune_step ¶
vla_finetune_step(vla, encoder: ActionTokenEncoderDecoder, actor: ActionTokenActor, q_critic: ActionTokenQCritic, episodes: List[ActionTokenEpisode], beta: float = 0.1, device: str = 'cuda', micro_batch: int = 4)
Full fine-tune: re-run VLA forward on stored images with gradients enabled.
Gradient path: actor_loss → actor → rl_token → encoder → action_queries → VLA The critic is frozen during this step (only provides Q signal, no param update).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
episodes | List[ActionTokenEpisode] | current iteration's episodes (must have .images stored) | required |
micro_batch | int | VLA forward batch size (controls GPU memory) | 4 |
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
action_token_td_critic_update ¶
action_token_td_critic_update(actor: ActionTokenActor, q_critic: ActionTokenQCritic, target_q_critic: ActionTokenQCritic, replay_buffer: ReplayBuffer, batch_size: int = 256, gamma: float = 0.99, device: str = 'cuda', target_noise_std: float = 0.2, target_noise_clip: float = 0.5, n_tasks: int = 0, target_actor: ActionTokenActor = None)
TD3-style twin-Q critic update from replay buffer (Eq. 3 in paper).
Target: Q̂ = Σ γ^t' r_t' + γ^C * min(Q1', Q2')(s', a') where a' ~ π_target(·|s', ã') + clipped noise (target policy smoothing)
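The target computation can be sketched as follows (toy callables stand in for the target twin Q-networks; names and shapes are illustrative):

```python
import torch

@torch.no_grad()
def td3_target(target_q1, target_q2, next_state, next_action, reward, done,
               gamma=0.99, chunk_len=8, noise_std=0.2, noise_clip=0.5):
    """Sketch of the Eq. 3 target: add clipped Gaussian noise to the target
    policy's next action (target policy smoothing), then bootstrap with
    min(Q1', Q2') discounted over the executed chunk length."""
    noise = (noise_std * torch.randn_like(next_action)).clamp(-noise_clip, noise_clip)
    a_next = next_action + noise
    q_next = torch.min(target_q1(next_state, a_next), target_q2(next_state, a_next))
    return reward + (gamma ** chunk_len) * (1.0 - done) * q_next

# Toy target critics returning 1.0 per batch element:
tq = lambda s, a: torch.ones(s.shape[0])
y = td3_target(tq, tq, torch.randn(4, 16), torch.randn(4, 8 * 7),
               reward=torch.zeros(4), done=torch.zeros(4))
print(y.shape)
```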
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_tasks | int | if > 0, use per-task stratified sampling for balanced multi-task update. | 0 |
target_actor | ActionTokenActor | Polyak-averaged actor for computing next actions (TD3). Falls back to online actor if None. | None |
Returns:
| Name | Type | Description |
|---|---|---|
critic_loss | scalar tensor (with grad on q_critic params) | |
stats | dict |
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
action_token_td_actor_update ¶
action_token_td_actor_update(actor: ActionTokenActor, q_critic: ActionTokenQCritic, replay_buffer: ReplayBuffer, batch_size: int = 256, beta: float = 1.0, device: str = 'cuda', n_tasks: int = 0)
DDPG-style actor update from the RL Token paper (Eq. 5):
L_π(θ) = E[ -Q_ψ(x, a) + β ‖a - ã‖² ]
where a ~ π_θ (stochastic rsample, paper Eq. 5). With fixed small std=0.1, the gradient direction is nearly identical to the deterministic mean, but we match the paper's formulation exactly.
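A minimal sketch of the Eq. 5 loss (a toy critic stands in for Q1; shapes are illustrative):

```python
import torch

def actor_loss(q1, rl_token, action, vla_ref, beta=1.0):
    """Sketch of Eq. 5: L = E[ -Q(x, a) + beta * ||a - a_ref||^2 ],
    where a is the sampled policy action and a_ref the VLA reference."""
    bc = ((action - vla_ref) ** 2).sum(dim=(-2, -1))   # per-sample BC penalty
    return (-q1(rl_token, action) + beta * bc).mean()

q1 = lambda s, a: torch.zeros(a.shape[0])   # toy critic
a = torch.zeros(2, 8, 7)
ref = torch.ones(2, 8, 7)
print(actor_loss(q1, torch.randn(2, 16), a, ref).item())  # 56.0
```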
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_tasks | int | if > 0, use per-task stratified sampling for balanced multi-task update. | 0 |
Returns:
| Name | Type | Description |
|---|---|---|
actor_loss | scalar tensor (with grad on actor params) | |
stats | dict |
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
action_token_td_update ¶
action_token_td_update(actor: ActionTokenActor, critic, replay_buffer: ReplayBuffer, batch_size: int = 256, gamma: float = 0.99, device: str = 'cuda', target_critic=None, beta: float = 1.0, update_actor: bool = True, target_noise_std: float = 0.2, target_noise_clip: float = 0.5)
Combined TD3-style update step (backward compat wrapper).
If critic is ActionTokenQCritic: uses the new TD3/DDPG-style from paper. If critic is ActionTokenCritic (legacy V(s)): falls back to old logic.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
beta | float | BC regularization coefficient (paper Eq. 5) | 1.0 |
update_actor | bool | If False, only update critic (TD3 delayed actor update) | True |
target_noise_std | float | Std of noise added to target policy actions | 0.2 |
target_noise_clip | float | Clip range for target policy noise | 0.5 |
Returns:
| Name | Type | Description |
|---|---|---|
loss | scalar tensor | |
stats | dict |
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
Fast rollout¶
action_token_rollout_fast ¶
Fast ActionToken rollout — step-lock architecture: all envs move in lockstep.
- Batch VLA forward for ALL active envs (one GPU call)
- Batch encoder + actor
- ALL envs execute chunk in parallel threads
- Collect results, repeat
No BatchInferenceServer needed. No async queuing. No batch fragmentation.
Speedup: ~50x vs original (env creation + batch fragmentation eliminated).
action_token_collect_group_steplock ¶
action_token_collect_group_steplock(env_pool: PersistentEnvPool, frozen_vla, encoder: ActionTokenEncoderDecoder, actor: ActionTokenActor, critic: ActionTokenCritic, suite_name: str, task_id: int, n_initial_states: int, action_norm_stats: dict, max_steps: int, chunk_len: int, G: int = 64, seed: int = 42, num_steps_wait: int = 10, device: str = 'cuda', video_dir: Optional[str] = None, group_idx: int = 0, store_images: bool = False, group_size: int = 1, reward_coef: float = 1.0, actor_chunk_len: int = None, env_offset: int = 0, warmup_mode: bool = False) -> List[ActionTokenEpisode]
Collect G episodes using step-lock architecture.
All envs move in lockstep
- Batch VLA forward (one GPU call for ALL active envs)
- Batch encoder + actor (or skip actor if warmup_mode)
- All envs execute chunk in parallel
- Repeat
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
env_offset | int | starting env index in the pool | 0 |
actor_chunk_len | int | if set, actor outputs shorter chunk than VLA | None |
warmup_mode | bool | if True, use VLA actions directly (skip actor). For buffer pre-fill. | False |
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_rollout_fast.py
action_token_collect_multitask_steplock ¶
action_token_collect_multitask_steplock(env_pool: PersistentEnvPool, frozen_vla, encoder: ActionTokenEncoderDecoder, actor: ActionTokenActor, critic: ActionTokenCritic, suite_name: str, task_ids: List[int], n_initial_states: int, action_norm_stats: dict, max_steps: int, chunk_len: int, G_per_task: int = 8, seed: int = 42, num_steps_wait: int = 10, device: str = 'cuda', group_idx: int = 0, store_images: bool = False, group_size: int = 1, reward_coef: float = 1.0, actor_chunk_len: int = None, warmup_mode: bool = False) -> List[ActionTokenEpisode]
Collect episodes for MULTIPLE tasks on ONE GPU in a single step-lock loop.
All tasks' envs are merged into one batch for VLA forward — no per-task threading, no CUDA concurrency issues, maximum GPU batch utilization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
task_ids | List[int] | list of task IDs to run on this GPU | required |
G_per_task | int | episodes per task | 8 |
Returns: flat list of all episodes across all tasks
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_rollout_fast.py
Shared components (common/)¶
Rollout¶
rollout ¶
Episode rollout for GRPO.
Key design
- Each episode stores trajectory as tensors: (traj_len, ...) for batched forward
- finish_step tracks actual episode length for masking
- Multiprocess env workers (one process per env)
- Gaussian policy: a ~ N(μ, σ²I), log_prob = -||a-μ||²/(2σ²)
StepRecord dataclass ¶
StepRecord(primary_image: ndarray, wrist_image: ndarray, instruction: str, norm_action: ndarray, old_log_prob: float, value: float = 0.0, action_token_ids: ndarray = None)
One inference step = one action-chunk prediction.
Replay buffer¶
replay_buffer ¶
Off-policy replay buffer for RLActionToken training (RL Token paper style).
Stores transitions as detached CPU tensors in a ring buffer. Each transition stores:
- rl_token (state), action_taken, reward, next_rl_token (next state), done
- vla_ref_action: the VLA reference action chunk (for BC regularization)
- next_vla_action: the VLA reference at the next state (for target policy sampling)
- prop_state / next_prop_state: proprioceptive state (eef_pos + axisangle + gripper, 8-dim)
- task_id: integer task index for per-task stratified sampling
ReplayBuffer ¶
Fixed-capacity ring buffer for off-policy experience replay.
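The ring-buffer mechanics can be sketched in plain Python (tensor contents omitted; the class name is illustrative):

```python
import random

class MiniRingBuffer:
    """Fixed-capacity ring buffer sketch: once capacity is reached, the
    write pointer wraps around and overwrites the oldest transition;
    sampling is uniform over whatever is currently stored."""
    def __init__(self, capacity):
        self.capacity, self.data, self.pos = capacity, [], 0

    def push(self, transition):
        if len(self.data) < self.capacity:
            self.data.append(transition)
        else:
            self.data[self.pos] = transition       # overwrite oldest
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.data, batch_size)

buf = MiniRingBuffer(capacity=3)
for i in range(5):
    buf.push(i)
print(sorted(buf.data))  # [2, 3, 4] — 0 and 1 were overwritten
```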
Source code in AlphaBrain/training/reinforcement_learning/common/replay_buffer.py
push ¶
push(rl_token: Tensor, vla_action: Tensor, action_taken: Tensor, reward: float, next_rl_token: Tensor, next_vla_action: Tensor, done: bool, task_id: int = 0, prop_state: Optional[Tensor] = None, next_prop_state: Optional[Tensor] = None)
Store a single transition (all tensors detached to CPU).
Source code in AlphaBrain/training/reinforcement_learning/common/replay_buffer.py
sample ¶
Sample a random mini-batch (uniform over all transitions).
Returns:
| Type | Description |
|---|---|
Tensor | Tuple of (rl_tokens, vla_actions, actions_taken, rewards, next_rl_tokens, next_vla_actions, dones, prop_states, next_prop_states) |
... | Each tensor has a batch dimension prepended and is moved to the target device. |
Source code in AlphaBrain/training/reinforcement_learning/common/replay_buffer.py
sample_balanced ¶
Per-task stratified sampling: sample equal number of transitions from each task.
Ensures each task contributes equally to each gradient update, matching GRPO's equal-frequency multi-task update property.
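A dependency-free sketch of the stratified sampling (dict-based transitions stand in for the buffer's tensor storage; the function name is illustrative):

```python
import random
from collections import defaultdict

def sample_balanced(transitions, batch_size, n_tasks):
    """Per-task stratified sampling sketch: draw batch_size // n_tasks
    transitions from each task's pool (with replacement, so a task with
    few transitions still fills its share)."""
    by_task = defaultdict(list)
    for tr in transitions:
        by_task[tr["task_id"]].append(tr)
    per_task = batch_size // n_tasks
    batch = []
    for t in range(n_tasks):
        if by_task[t]:
            batch.extend(random.choices(by_task[t], k=per_task))
    return batch

data = [{"task_id": i % 2, "reward": 0.0} for i in range(10)]
batch = sample_balanced(data, batch_size=8, n_tasks=2)
print(len(batch))  # 8 — four transitions from each task
```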
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch_size | int | total transitions to sample | required |
n_tasks | int | number of tasks (0..n_tasks-1) | required |
device | str | target device | 'cuda' |
Source code in AlphaBrain/training/reinforcement_learning/common/replay_buffer.py
is_ready ¶
task_counts ¶
Return number of transitions per task (for diagnostics).
Checkpoint I/O¶
ckpt_io ¶
Checkpoint save helper shared by all RLActionToken training phases.
Environments (envs/)¶
LIBERO environment¶
libero_env ¶
LIBERO environment proxy — runs in the VLA Python environment.
Spawns libero_env_worker.py as a subprocess using LIBERO_PYTHON (the separate conda env that has libero installed), then communicates via stdin/stdout with length-prefixed msgpack messages.
Usage matches the original direct API
env = LiberoEnv(suite_name, task_id, seed)
obs = env.reset(initial_state_idx=0)
obs, reward, done = env.step(action_7d)
env.close()
LiberoEnv ¶
Proxy to a LIBERO environment running in a separate Python process.
The worker process is started once per LiberoEnv instance and reused across reset() calls (different tasks can be loaded with reset).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
libero_python | Optional[str] | Path to the LIBERO conda env Python binary. Defaults to LIBERO_PYTHON env var, then 'python'. | None |
Source code in AlphaBrain/training/reinforcement_learning/envs/libero_env.py
reset ¶
Reset the environment to a specific task and initial state.
Returns obs dict
- "primary_image" : PIL.Image
- "wrist_image" : PIL.Image
- "state" : np.ndarray (8,)
Source code in AlphaBrain/training/reinforcement_learning/envs/libero_env.py
step ¶
Execute one env step.
Returns:
| Name | Type | Description |
|---|---|---|
obs_dict | dict | parsed observation |
reward | float | 0.0 / 1.0 |
done | bool | episode termination flag |
Source code in AlphaBrain/training/reinforcement_learning/envs/libero_env.py
get_suite_info ¶
Query task count and task names from the LIBERO worker without opening an environment.
Returns: {"n_tasks": int, "task_names": [str, ...]}
Source code in AlphaBrain/training/reinforcement_learning/envs/libero_env.py
LIBERO environment workers¶
libero_env_worker ¶
LIBERO environment worker — runs inside the LIBERO Python environment.
Launched as a subprocess by libero_env.py (VLA Python env). Communicates via stdin/stdout using length-prefixed msgpack messages.
Protocol (both directions): [4-byte little-endian uint32 length][msgpack payload]

Commands received (from parent):

```
{"cmd": "reset", "task_suite": str, "task_id": int, "initial_state_idx": int, "seed": int}
{"cmd": "step", "action": [7 floats]}
```

Responses sent (to parent):

```
{"status": "ok", "obs": {"primary":
```
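The length-prefixed framing can be illustrated with a small helper. This is a sketch only: the real worker serializes dicts with msgpack, which is left abstract here as raw payload bytes; the function names are assumptions:

```python
import io
import struct

def write_msg(stream, payload: bytes) -> None:
    # Frame: [4-byte little-endian uint32 length][payload bytes]
    stream.write(struct.pack("<I", len(payload)) + payload)

def read_msg(stream) -> bytes:
    # Read the fixed-size header first, then exactly `length` payload bytes.
    header = stream.read(4)
    if len(header) < 4:
        raise EOFError("stream closed while reading frame header")
    (length,) = struct.unpack("<I", header)
    payload = stream.read(length)
    if len(payload) < length:
        raise EOFError("stream closed mid-payload")
    return payload

# Round-trip over an in-memory stream; the real proxy uses the worker's
# stdin/stdout with msgpack-serialized dicts as the payload.
buf = io.BytesIO()
write_msg(buf, b'{"cmd": "step"}')
buf.seek(0)
```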
libero_env_worker_fast ¶
Fast LIBERO environment worker — socket pair IPC, MuJoCo env reuse.
Launched by persistent_env_pool.py. Receives socket FD as argv[1]. Communicates via socket (not stdin/stdout) — eliminates pipe buffer deadlock.
Key optimization: same task → only reset() + set_init_state() (no env recreation).
Persistent environment pool¶
persistent_env_pool ¶
Persistent LiberoEnv pool — keeps subprocess envs alive across rollout iterations.
IPC: socket pair (bidirectional, with settimeout) instead of stdin/stdout pipes. This eliminates the pipe buffer deadlock that occurs with high-concurrency pipe I/O.
Two-layer optimization
- Subprocess pool: LiberoEnv subprocesses created once, reused across iterations
- Fast worker: libero_env_worker_fast.py reuses MuJoCo env for same task
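The socket-pair IPC can be sketched as follows. This is a minimal single-process illustration using a thread as a stand-in for the worker; the real pool launches `libero_env_worker_fast.py` as a subprocess and passes the file descriptor as `argv[1]`:

```python
import socket
import threading

def echo_worker(fd: int) -> None:
    # Stand-in for the worker process: wraps the inherited FD in a
    # socket object, reads one request, and replies.
    sock = socket.socket(fileno=fd)
    data = sock.recv(4096)
    sock.sendall(data.upper())
    sock.close()

parent, child = socket.socketpair()
parent.settimeout(5.0)  # bounded waits instead of indefinitely blocking reads

# detach() hands over the raw FD, mirroring how the pool passes an FD
# to the worker subprocess on its command line.
t = threading.Thread(target=echo_worker, args=(child.detach(),))
t.start()
parent.sendall(b"reset")
reply = parent.recv(4096)
t.join()
parent.close()
```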
PersistentEnvPool ¶
PersistentEnvPool(num_envs: int, libero_python: Optional[str] = None, egl_gpu_id: Optional[int] = None)
Pool of persistent fast LiberoEnv subprocess workers.
Each worker subprocess stays alive for the entire training run. On reset with same task: only reset() + set_init_state() in worker (fast). On reset with different task: worker recreates MuJoCo env (slower, but rare).
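The same-task fast path can be sketched with a small cache. The names here are hypothetical (the real worker wraps a MuJoCo-backed LIBERO env), but `reset()` and `set_init_state()` mirror the calls named above:

```python
class EnvCache:
    """Reuse the heavy env object when consecutive resets target the same task."""

    def __init__(self, make_env):
        self._make_env = make_env  # factory: task_id -> env (expensive to call)
        self._task_id = None
        self._env = None

    def reset(self, task_id, init_state):
        if task_id != self._task_id:
            # Slow path (rare): different task -> recreate the env.
            if self._env is not None:
                self._env.close()
            self._env = self._make_env(task_id)
            self._task_id = task_id
        # Fast path: same task -> just reset and restore the initial state.
        self._env.reset()
        self._env.set_init_state(init_state)
        return self._env
```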
Source code in AlphaBrain/training/reinforcement_learning/envs/persistent_env_pool.py
Evaluation (eval/)¶
eval_libero ¶
Standalone offline eval for an RLActionToken iter checkpoint on LIBERO.
Loads:

- frozen QwenOFT VLA from --vla_ckpt
- ActionTokenEncoderDecoder from …/encoder.pt
- ActionTokenActor from …/actor.pt
Runs deterministic eval across all (or selected) tasks of the suite, prints per-task SR, and optionally appends the result to a JSON file.
The eval protocol is shared with training via AlphaBrain.training.reinforcement_learning.eval.eval_helpers._eval_deterministic_local.
eval_helpers ¶
Deterministic eval helpers shared by training loops and standalone eval.
aggregate_shards ¶
Aggregate per-shard eval JSONs into a single summary.
Each shard JSON (produced by eval_libero.py --results_json shard_X.json) contains a per_task_sr dict for the subset of tasks that shard ran. This script merges them into one summary with all per-task SRs and the overall SR (mean across tasks).
Usage:

```
python AlphaBrain/training/reinforcement_learning/eval/aggregate_shards.py \
    --out_dir
```
Training entrypoints (trainers/)¶
Shared CLI arguments¶
train_args ¶
CLI args for RLActionToken training (all three phases).
Main entrypoint¶
train ¶
ActionToken training entry for QwenOFT on LIBERO.
Three phases:

- `--phase pretrain`: encoder-decoder pretraining via reconstruction loss
- `--phase rl`: on-policy multi-GPU PPO update (legacy; PPO/GRPO is a TODO)
- `--phase rl_offpolicy`: off-policy TD3 with split rollout/training GPUs (production)
Usage
Phase 1: Encoder pretraining¶
```
python AlphaBrain/training/reinforcement_learning/trainers/train.py \
    --phase pretrain \
    --ckpt_path results/training/my_sft/final_model \
    --suite libero_goal --task_id 0
```
Phase 2 (production): off-policy TD3¶
```
python AlphaBrain/training/reinforcement_learning/trainers/train.py \
    --phase rl_offpolicy \
    --ckpt_path results/training/my_sft/final_model \
    --encoder_path results/action_token_training/pretrain/checkpoints/pretrain_best/encoder.pt \
    --suite libero_goal --task_id 0
```
Pretrain¶
train_pretrain ¶
Phase 1: encoder-decoder pretraining via reconstruction loss.
On-policy RL¶
train_rl_onpolicy ¶
Phase 2 (legacy on-policy variant): multi-GPU rollout + PPO update.
NOTE: this is the legacy on-policy training path. The off-policy TD3 variant (train_rl_offpolicy.run_rl_offpolicy) is the production code path used by every release run; this file is kept for reference.
TODO: implement proper PPO / GRPO updates here. The current action_token_ppo_loss is a placeholder that mixes value-loss + clipped surrogate; before relying on this phase for new experiments, port a battle-tested PPO loop (importance sampling, KL early stop, value clipping, gradient accumulation, group-relative normalization for GRPO, etc.) from a reference implementation.
run_rl ¶
Phase 2 on-policy: multi-GPU parallel rollout + PPO update.
Each GPU loads a frozen VLA copy and runs its own env workers to collect episodes in parallel. All episodes are gathered, then the tiny RLActionToken network update happens on every rank (identical, since network is tiny).
6 GPUs = 6× rollout throughput (the actual bottleneck is CPU env.step).
Source code in AlphaBrain/training/reinforcement_learning/trainers/train_rl_onpolicy.py
Off-policy RL¶
train_rl_offpolicy ¶
Phase 2 (production): off-policy TD3 with split rollout/training GPUs.
Architecture (single process, no Accelerate):

- Rollout GPUs: each loads a frozen VLA copy, collects episodes in parallel
- Train GPU: runs actor-critic TD updates (backward pass)
- Replay buffer: centralized on CPU, fed by all rollout GPUs
- Eval: distributed across rollout GPUs for speed
run_rl_offpolicy ¶
Phase 2 off-policy variant: split rollout/training GPUs.
Lets you freely assign GPUs, e.g. `--rollout_gpus 0,1,2,3,4 --train_gpu 5`. The train GPU can overlap with the rollout GPUs, since rollout and training run sequentially.
Source code in AlphaBrain/training/reinforcement_learning/trainers/train_rl_offpolicy.py