
Training › Reinforcement Learning

Source path: AlphaBrain/training/reinforcement_learning/

Full implementation of VLA online RL training (RLActionToken — RL Token). Paper: RL Token: Bootstrapping Online RL with VLA Models (Physical Intelligence).

Layout:

  • algos/RLActionToken/ — encoder/decoder, actor-critic, trainer, fast rollout
  • common/ — rollout, replay buffer, checkpoint I/O
  • envs/ — LIBERO environment, persistent env pool, env workers
  • eval/ — LIBERO evaluation and shard aggregation
  • trainers/ — on-policy / off-policy / pretrain entrypoints

Top-level re-exports

reinforcement_learning


RLActionToken algorithm

Encoder / Decoder

action_token_encoder_decoder

ActionToken Encoder-Decoder: Information bottleneck between frozen VLA and small RL network.

Inspired by the "RL Token" paper (Physical Intelligence, 2026), but with deviations from the paper's construction:

  • Encoder input is the VLA's action-query hidden states (M × H) gathered at the action-token positions, not the full image-token sequence (N × H) as in the paper's Fig. 2.
  • An extra Linear(H → bottleneck_dim) projection compresses the per-token dimension (e.g. 2048 → 256); the paper keeps the RL token at the VLA hidden dim.
  • The decoder is a self-attention transformer with a causal mask and a prefix token, not the encoder-decoder cross-attention structure of the paper's Eq. 2.

A faithful, paper-accurate reimplementation is still under test.

Paper Eq. 1 — Encoder: z_rl = g_φ([z_{1:M}, e_rl])_{M+1}
Append the learnable embedding e_rl to the VLA token sequence, run it through a self-attention encoder transformer, and take the output at the e_rl position.

Paper Eq. 2 — Decoder (autoregressive reconstruction): L_ro = E[ Σ_{i=1}^{M} ‖h_φ(d_φ([z_rl, sg(z_{1:i-1})]))_i − sg(z_i)‖² ]
Reconstruct the VLA tokens autoregressively from z_rl to enforce information preservation in the bottleneck.
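Eq. 2 can be sketched on toy tensors to show the teacher-forcing shift and the stop-gradient targets (a minimal illustration; the dimensions, the zero prefix, and the single `Linear` standing in for h_φ ∘ d_φ are placeholders, not the project's API):

```python
import torch

torch.manual_seed(0)
B, M, H = 2, 8, 16                         # batch, chunk length, VLA hidden dim

# Toy stand-in for the VLA token sequence z_{1:M} (treated as frozen targets)
z = torch.randn(B, M, H)
decoder = torch.nn.Linear(H, H)            # placeholder for h_φ ∘ d_φ

# Teacher forcing: input at position i is sg(z_{i-1}); position 0 sees only z_rl
prefix = torch.zeros(B, 1, H)              # stand-in for the projected z_rl prefix
shifted = torch.cat([prefix, z[:, :-1, :].detach()], dim=1)   # (B, M, H)

# Eq. 2: per-position squared error against stop-gradient targets
recon = decoder(shifted)                                       # (B, M, H)
loss = ((recon - z.detach()) ** 2).sum(dim=-1).sum(dim=-1).mean()
```

Only the decoder parameters receive gradients from `loss`; the targets are detached, mirroring the sg(·) operator in the paper.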

ActionTokenEncoder
ActionTokenEncoder(input_dim: int = 2048, bottleneck_dim: int = 256, num_heads: int = 4, num_layers: int = 2, dropout: float = 0.0)

Bases: Module

Paper Eq. 1: Compress VLA action_queries (B, M, H) → rl_token (B, 1, D).

Appends a learnable e_rl to the token sequence and processes with self-attention (TransformerEncoderLayer). The output at the e_rl position is projected to the bottleneck dimension.

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_encoder_decoder.py
def __init__(
    self,
    input_dim: int = 2048,
    bottleneck_dim: int = 256,
    num_heads: int = 4,
    num_layers: int = 2,
    dropout: float = 0.0,
):
    super().__init__()
    self.input_dim = input_dim
    self.bottleneck_dim = bottleneck_dim

    # Learnable RL embedding e_rl (appended to token sequence)
    self.cls_token = nn.Parameter(torch.randn(1, 1, input_dim) * 0.02)

    # Self-attention encoder layers (paper: g_φ processes [z_{1:M}, e_rl])
    self.self_attn_layers = nn.ModuleList([
        nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=input_dim * 2,
            dropout=dropout,
            batch_first=True,
        )
        for _ in range(num_layers)
    ])
    self.bottleneck_proj = nn.Linear(input_dim, bottleneck_dim)
forward
forward(action_queries: Tensor) -> torch.Tensor

Parameters:

  • action_queries (Tensor, required) — (B, M, H) from the frozen VLA

Returns: rl_token: (B, 1, D_bottleneck)

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_encoder_decoder.py
def forward(self, action_queries: torch.Tensor) -> torch.Tensor:
    """
    Args:
        action_queries: (B, M, H) from frozen VLA
    Returns:
        rl_token: (B, 1, D_bottleneck)
    """
    action_queries = action_queries.float()
    B = action_queries.size(0)
    cls = self.cls_token.expand(B, -1, -1)               # (B, 1, H)
    seq = torch.cat([action_queries, cls], dim=1)         # (B, M+1, H)
    for layer in self.self_attn_layers:
        seq = layer(seq)                                  # (B, M+1, H)
    rl_token = self.bottleneck_proj(seq[:, -1:, :])       # (B, 1, D)
    return rl_token
ActionTokenDecoder
ActionTokenDecoder(bottleneck_dim: int = 256, output_dim: int = 2048, chunk_len: int = 8, num_heads: int = 4, num_layers: int = 2, dropout: float = 0.0)

Bases: Module

Paper Eq. 2: Autoregressive reconstruction of VLA tokens from z_rl.

L_ro = E[ Σ_i ‖h_φ(d_φ([z_rl, sg(z_{1:i-1})]))_i − sg(z_i)‖² ]

The decoder takes z_rl as prefix, and autoregressively reconstructs each VLA token conditioned on z_rl and previously reconstructed tokens. A causal mask ensures position i can only attend to positions < i (plus the z_rl prefix which is always visible).

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_encoder_decoder.py
def __init__(
    self,
    bottleneck_dim: int = 256,
    output_dim: int = 2048,
    chunk_len: int = 8,
    num_heads: int = 4,
    num_layers: int = 2,
    dropout: float = 0.0,
):
    super().__init__()
    self.chunk_len = chunk_len
    self.output_dim = output_dim
    self.expand_proj = nn.Linear(bottleneck_dim, output_dim)
    self.pos_embed = nn.Parameter(torch.randn(1, chunk_len, output_dim) * 0.02)

    # Self-attention decoder layers with causal masking
    self.self_attn_layers = nn.ModuleList([
        nn.TransformerEncoderLayer(
            d_model=output_dim,
            nhead=num_heads,
            dim_feedforward=output_dim * 2,
            dropout=dropout,
            batch_first=True,
        )
        for _ in range(num_layers)
    ])
forward
forward(rl_token: Tensor, target_tokens: Tensor = None) -> torch.Tensor

Parameters:

  • rl_token (Tensor, required) — (B, 1, D_bottleneck)
  • target_tokens (Tensor, default None) — (B, M, H) stop-gradient VLA tokens for teacher forcing. If None, uses learned positional embeddings (inference mode).

Returns: reconstructed: (B, M, H)

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_encoder_decoder.py
def forward(
    self,
    rl_token: torch.Tensor,
    target_tokens: torch.Tensor = None,
) -> torch.Tensor:
    """
    Args:
        rl_token: (B, 1, D_bottleneck)
        target_tokens: (B, M, H) stop-gradient VLA tokens for teacher forcing.
                       If None, uses learned positional embeddings (inference mode).
    Returns:
        reconstructed: (B, M, H)
    """
    B = rl_token.size(0)
    prefix = self.expand_proj(rl_token)                   # (B, 1, H)

    if target_tokens is not None:
        # Training: teacher forcing with stop-gradient targets
        # Sequence: [z_rl, sg(z_1), sg(z_2), ..., sg(z_{M-1})]
        # Target:   [z_1,  z_2,     z_3,     ...,  z_M        ]
        # Shifted input: z_rl is position 0, z_1 is position 1, etc.
        shifted_input = target_tokens[:, :-1, :].detach()  # (B, M-1, H)
        seq = torch.cat([prefix, shifted_input], dim=1)    # (B, M, H)
        seq = seq + self.pos_embed                         # add positional info
    else:
        # Inference: use positional embeddings (no teacher forcing)
        seq = prefix.expand(-1, self.chunk_len, -1) + self.pos_embed  # (B, M, H)

    # Causal mask: position i can only attend to positions <= i
    # This ensures autoregressive structure
    M = seq.size(1)
    causal_mask = torch.triu(
        torch.ones(M, M, device=seq.device, dtype=torch.bool), diagonal=1
    )  # True = masked out

    for layer in self.self_attn_layers:
        seq = layer(seq, src_mask=causal_mask, is_causal=True)

    return seq  # (B, M, H) — each position predicts the corresponding target
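The masking convention can be checked in isolation (self-contained sketch; `True` entries are masked out, matching the `src_mask` convention of `nn.TransformerEncoderLayer`):

```python
import torch

M = 4
# True = masked out: position i may attend only to positions j <= i
causal_mask = torch.triu(torch.ones(M, M, dtype=torch.bool), diagonal=1)
print(causal_mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

# Position 0 (the z_rl prefix) is visible to every later position, so each
# reconstructed token is conditioned on z_rl plus strictly earlier tokens.
assert not causal_mask[2, 0] and not causal_mask[2, 1] and not causal_mask[2, 2]
assert bool(causal_mask[2, 3])
```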
ActionTokenEncoderDecoder
ActionTokenEncoderDecoder(input_dim: int = 2048, bottleneck_dim: int = 256, chunk_len: int = 8, num_heads: int = 4, encoder_layers: int = 2, decoder_layers: int = 2, dropout: float = 0.0)

Bases: Module

Combined Encoder-Decoder for ActionToken pretraining.

Training: autoregressive reconstruction with teacher forcing. Inference: encoder only (decoder not used during RL).

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_encoder_decoder.py
def __init__(
    self,
    input_dim: int = 2048,
    bottleneck_dim: int = 256,
    chunk_len: int = 8,
    num_heads: int = 4,
    encoder_layers: int = 2,
    decoder_layers: int = 2,
    dropout: float = 0.0,
):
    super().__init__()
    self.encoder = ActionTokenEncoder(
        input_dim=input_dim,
        bottleneck_dim=bottleneck_dim,
        num_heads=num_heads,
        num_layers=encoder_layers,
        dropout=dropout,
    )
    self.decoder = ActionTokenDecoder(
        bottleneck_dim=bottleneck_dim,
        output_dim=input_dim,
        chunk_len=chunk_len,
        num_heads=num_heads,
        num_layers=decoder_layers,
        dropout=dropout,
    )
forward
forward(action_queries: Tensor)

Full encode-decode pass with autoregressive reconstruction loss.

Paper Eq. 2: L_ro = E[ Σ_i ‖reconstructed_i − sg(z_i)‖² ]

Returns:

  • rl_token — (B, 1, D)
  • recon_loss — scalar MSE reconstruction loss

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_encoder_decoder.py
def forward(self, action_queries: torch.Tensor):
    """
    Full encode-decode pass with autoregressive reconstruction loss.

    Paper Eq. 2:
      L_ro = E[ Σ_i ‖reconstructed_i − sg(z_i)‖² ]

    Returns:
        rl_token: (B, 1, D)
        recon_loss: scalar MSE reconstruction loss
    """
    action_queries = action_queries.float()
    rl_token = self.encoder(action_queries)
    # Autoregressive decode with teacher forcing
    reconstructed = self.decoder(rl_token, target_tokens=action_queries.detach())
    recon_loss = F.mse_loss(reconstructed, action_queries.detach())
    return rl_token, recon_loss

Actor / Critic

action_token_actor_critic

ActionToken Actor-Critic, following the RL Token paper (Physical Intelligence) closely on the actor/critic side (the deviations from the paper live in action_token_encoder_decoder.py and action_token_trainer.py).

Key design choices from the paper

  • Actor (Eq. 4): π_θ(a | x, ã) = N(μ_θ(x, ã), σ²I). The actor directly outputs the action chunk, conditioned on (rl_token, vla_ref); the VLA reference is an input to the network, not a structural residual. BC regularization β‖a − ã‖² in the loss keeps actions close to the VLA, and reference-action dropout (50%) prevents identity collapse.
  • Critic (Eq. 3): Q(s, a) — twin Q-networks (TD3-style).
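The loss-level coupling to the VLA reference (rather than a structural residual) can be sketched with toy tensors; `beta`, the toy Q function, and all shapes here are illustrative placeholders:

```python
import torch

torch.manual_seed(0)
B, C, A = 4, 8, 7                      # batch, chunk length, action dim
beta = 0.1                             # BC regularization weight (illustrative)

rl_token = torch.randn(B, 256)
vla_ref = torch.randn(B, C, A)         # ã: frozen VLA's reference action chunk
action = torch.randn(B, C, A, requires_grad=True)   # a ~ π_θ (stand-in)

def toy_q(tok, act):                   # placeholder critic Q(s, a)
    return -(act ** 2).flatten(1).sum(-1)

# Paper-style actor objective: maximize Q while staying near the VLA reference
actor_loss = -toy_q(rl_token, action).mean() \
             + beta * ((action - vla_ref) ** 2).flatten(1).sum(-1).mean()
actor_loss.backward()
```

The BC term pulls the action toward ã only through the loss; nothing in the architecture forces the output to be a delta on the reference.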
ActionTokenActor
ActionTokenActor(bottleneck_dim: int = 256, action_dim: int = 7, chunk_len: int = 8, hidden_dim: int = 256, ref_dropout: float = 0.5, fixed_std: float = 0.1, prop_dim: int = 0)

Bases: Module

ActionToken actor from paper (Eq. 4-5).

π_θ(a_{1:C} | x, ã_{1:C}) = N(μ_θ(x, ã_{1:C}), σ²I)

The network takes (rl_token, vla_reference_action) as input and DIRECTLY outputs the full action chunk. The VLA reference is just a conditioning signal — the BC regularization in the loss (not in the architecture) keeps the output close to VLA.

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
def __init__(
    self,
    bottleneck_dim: int = 256,
    action_dim: int = 7,
    chunk_len: int = 8,
    hidden_dim: int = 256,    # paper: 256 for most tasks, 512 for hard
    ref_dropout: float = 0.5,  # paper: 50%
    fixed_std: float = 0.1,    # paper: small fixed std
    prop_dim: int = 0,         # proprioceptive state dim (paper: eef_pos+axisangle+gripper=8)
):
    super().__init__()
    self.action_dim = action_dim
    self.chunk_len = chunk_len
    self.ref_dropout = ref_dropout
    self.prop_dim = prop_dim

    flat_action_dim = action_dim * chunk_len
    input_dim = bottleneck_dim + prop_dim + flat_action_dim

    # Paper Appendix B: two-layer MLP (256 hidden) for most tasks,
    # three-layer MLP (512 hidden) for screw task
    self.net = nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, flat_action_dim),
    )

    # Kaiming init for hidden layers (default), small normal for output layer.
    # Paper: actor directly outputs actions; BC regularization in the loss
    # (not architecture) keeps output close to VLA reference.
    nn.init.normal_(self.net[-1].weight, std=0.01)
    nn.init.zeros_(self.net[-1].bias)

    # Paper: small fixed standard deviation
    self.register_buffer("fixed_std", torch.tensor(fixed_std))
forward
forward(rl_token: Tensor, vla_action: Tensor, prop_state: Tensor = None, deterministic: bool = False)

Parameters:

  • rl_token (Tensor, required) — (B, 1, D) or (B, D)
  • vla_action (Tensor, required) — (B, chunk_len, action_dim), the VLA reference
  • prop_state (Tensor, default None) — (B, prop_dim) proprioceptive state (eef_pos + axisangle + gripper)

Returns: action: (B, chunk_len, action_dim); log_prob: (B,) or None if deterministic

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
def forward(
    self,
    rl_token: torch.Tensor,
    vla_action: torch.Tensor,
    prop_state: torch.Tensor = None,
    deterministic: bool = False,
):
    """
    Args:
        rl_token: (B, 1, D) or (B, D)
        vla_action: (B, chunk_len, action_dim) — VLA reference
        prop_state: (B, prop_dim) — proprioceptive state (eef_pos+axisangle+gripper)
    Returns:
        action: (B, chunk_len, action_dim)
        log_prob: (B,) or None if deterministic
    """
    mean = self._get_mean(rl_token, vla_action, prop_state,
                          apply_dropout=(self.training and not deterministic))

    if deterministic:
        return mean, None

    std = self.fixed_std.expand_as(mean)
    dist = torch.distributions.Normal(mean, std)
    action = dist.rsample()
    log_prob = dist.log_prob(action).sum(dim=(-2, -1))  # (B,)
    return action, log_prob
log_prob_of
log_prob_of(rl_token: Tensor, vla_action: Tensor, taken_action: Tensor, prop_state: Tensor = None) -> torch.Tensor

Compute log_prob of a previously taken action under current policy.

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
def log_prob_of(
    self,
    rl_token: torch.Tensor,
    vla_action: torch.Tensor,
    taken_action: torch.Tensor,
    prop_state: torch.Tensor = None,
) -> torch.Tensor:
    """Compute log_prob of a previously taken action under current policy."""
    mean = self._get_mean(rl_token, vla_action, prop_state, apply_dropout=False)
    std = self.fixed_std.expand_as(mean)
    dist = torch.distributions.Normal(mean, std)
    return dist.log_prob(taken_action).sum(dim=(-2, -1))  # (B,)
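With a fixed σ the policy is a diagonal Gaussian, so an importance ratio for PPO-style updates follows directly from log-prob differences. A self-contained sketch with toy means (the clip range and advantages are placeholders):

```python
import torch

torch.manual_seed(0)
B, C, A = 4, 8, 7
std = 0.1

old_mean = torch.randn(B, C, A)
new_mean = old_mean + 0.01 * torch.randn(B, C, A)   # slightly updated policy
taken = torch.distributions.Normal(old_mean, std).sample()

# Log-probs of the same taken action under old and new policies, summed over
# the chunk and action dims — matching log_prob_of's reduction.
old_lp = torch.distributions.Normal(old_mean, std).log_prob(taken).sum(dim=(-2, -1))
new_lp = torch.distributions.Normal(new_mean, std).log_prob(taken).sum(dim=(-2, -1))

ratio = torch.exp(new_lp - old_lp)      # π_new(a|s) / π_old(a|s), shape (B,)
adv = torch.randn(B)                    # toy advantages
ppo_obj = torch.min(ratio * adv, torch.clamp(ratio, 0.8, 1.2) * adv).mean()
```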
ActionTokenQCritic
ActionTokenQCritic(bottleneck_dim: int = 256, action_dim: int = 7, chunk_len: int = 8, hidden_dim: int = 256, prop_dim: int = 0)

Bases: Module

Twin Q-critic from RL Token paper (Eq. 3, following TD3).

Q_ψ(x, a_{1:C}) takes the RL token state AND the action chunk as input. Contains two independent Q-networks; use min(Q1, Q2) for target values.

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
def __init__(
    self,
    bottleneck_dim: int = 256,
    action_dim: int = 7,
    chunk_len: int = 8,
    hidden_dim: int = 256,
    prop_dim: int = 0,  # proprioceptive state dim
):
    super().__init__()
    self.action_dim = action_dim
    self.chunk_len = chunk_len
    self.prop_dim = prop_dim

    flat_action_dim = action_dim * chunk_len
    input_dim = bottleneck_dim + prop_dim + flat_action_dim

    # Twin Q-networks (TD3 style)
    self.q1 = nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, 1),
    )
    self.q2 = nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, 1),
    )
forward
forward(rl_token: Tensor, action: Tensor, prop_state: Tensor = None) -> tuple

Returns: q1: (B,), q2: (B,)

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
def forward(
    self,
    rl_token: torch.Tensor,
    action: torch.Tensor,
    prop_state: torch.Tensor = None,
) -> tuple:
    """
    Returns: q1: (B,), q2: (B,)
    """
    if rl_token.dim() == 3:
        rl_token = rl_token.squeeze(1)
    B = rl_token.size(0)
    action_flat = action.reshape(B, -1)
    if self.prop_dim > 0:
        if prop_state is None:
            prop_state = torch.zeros(B, self.prop_dim, device=rl_token.device,
                                     dtype=rl_token.dtype)
        x = torch.cat([rl_token, prop_state, action_flat], dim=-1)
    else:
        x = torch.cat([rl_token, action_flat], dim=-1)
    return self.q1(x).squeeze(-1), self.q2(x).squeeze(-1)
q1_forward
q1_forward(rl_token: Tensor, action: Tensor, prop_state: Tensor = None) -> torch.Tensor

Single Q1 forward (used for actor loss to save compute).

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
def q1_forward(
    self,
    rl_token: torch.Tensor,
    action: torch.Tensor,
    prop_state: torch.Tensor = None,
) -> torch.Tensor:
    """Single Q1 forward (used for actor loss to save compute)."""
    if rl_token.dim() == 3:
        rl_token = rl_token.squeeze(1)
    B = rl_token.size(0)
    action_flat = action.reshape(B, -1)
    if self.prop_dim > 0:
        if prop_state is None:
            prop_state = torch.zeros(B, self.prop_dim, device=rl_token.device,
                                     dtype=rl_token.dtype)
        x = torch.cat([rl_token, prop_state, action_flat], dim=-1)
    else:
        x = torch.cat([rl_token, action_flat], dim=-1)
    return self.q1(x).squeeze(-1)
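The clipped double-Q target that motivates min(Q1, Q2) can be written out with toy numbers (a sketch; reward, discount, and Q values are placeholders):

```python
import torch

reward = torch.tensor([1.0, 0.0, 0.5])
done = torch.tensor([0.0, 0.0, 1.0])
gamma = 0.99

# Target-network Q-values for the next (state, actor(state)) pair
q1_next = torch.tensor([2.0, 1.5, 3.0])
q2_next = torch.tensor([1.8, 1.7, 2.5])

# TD3: take the elementwise minimum of the twin Qs to curb overestimation
target = reward + gamma * (1.0 - done) * torch.min(q1_next, q2_next)
print(target)   # tensor([2.7820, 1.4850, 0.5000])
```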
ActionTokenCritic
ActionTokenCritic(bottleneck_dim: int = 256, hidden_dim: int = 512)

Bases: Module

State value estimator V(s) from rl_token. (Legacy, for PPO path only.)

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
def __init__(
    self,
    bottleneck_dim: int = 256,
    hidden_dim: int = 512,
):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(bottleneck_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, 1),
    )
soft_update_target
soft_update_target(source: Module, target: Module, tau: float = 0.005)

Polyak averaging: target = (1 - tau) * target + tau * source.

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_actor_critic.py
def soft_update_target(source: nn.Module, target: nn.Module, tau: float = 0.005):
    """Polyak averaging: target = (1 - tau) * target + tau * source."""
    with torch.no_grad():
        for sp, tp in zip(source.parameters(), target.parameters()):
            tp.data.mul_(1.0 - tau).add_(sp.data, alpha=tau)
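Typical usage keeps a frozen copy of the critic and nudges it toward the live network after each update (self-contained sketch; the `+1.0` perturbation stands in for a gradient step):

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
critic = nn.Linear(8, 1)
target_critic = copy.deepcopy(critic)       # starts identical to the live critic
for p in target_critic.parameters():
    p.requires_grad_(False)                 # target is never trained directly

# Perturb the live critic as a stand-in for a gradient step
with torch.no_grad():
    critic.weight.add_(1.0)

# Polyak averaging: target = (1 - tau) * target + tau * source
tau = 0.005
with torch.no_grad():
    for sp, tp in zip(critic.parameters(), target_critic.parameters()):
        tp.data.mul_(1.0 - tau).add_(sp.data, alpha=tau)

# The target moved tau = 0.5% of the way toward the live critic,
# so the remaining gap is 1 - tau = 0.995 in every weight entry.
gap = (critic.weight - target_critic.weight).abs().mean()
```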

Trainer (loss / update)

action_token_trainer

ActionToken Trainer: Two-phase training for the RLActionToken variant.

Phase 1 — Encoder Pretraining: Freeze VLA, train encoder-decoder via reconstruction loss on rollout data.

Phase 2 — Actor-Critic RL: Freeze VLA, use pretrained encoder. Rollout with actor (action editing), update actor + critic via PPO-clip. Optional encoder fine-tuning with reconstruction regularization.

BatchInferenceServer
BatchInferenceServer(frozen_vla, encoder, actor, critic, device: str, max_batch_size: int = 64, batch_timeout_s: float = 0.005, actor_chunk_len: int = None)

Batched VLA + encoder + actor + critic inference for rollout.

Problem with model_lock (old approach):

  • Each env thread acquires the lock → VLA forward with batch=1 → releases the lock
  • The GPU processes one observation at a time regardless of how many envs are running
  • With chunk_len=8, the GPU is idle 7/8 of the time (CPU env.step dominates)

Solution
  • All env threads submit (img_pair, instruction) to a request queue and block
  • A single background thread drains the queue and runs ONE batched GPU forward
  • All waiting threads get results simultaneously → GPU utilization scales with num_envs
Cross-task batching
  • In multi-task mode, multiple tasks on the same GPU share ONE server
  • When task-0 envs and task-5 envs both need VLA inference, they batch together
  • Effective batch size = n_tasks_per_gpu × n_envs_per_task
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
def __init__(
    self,
    frozen_vla,
    encoder,
    actor,
    critic,
    device: str,
    max_batch_size: int = 64,
    batch_timeout_s: float = 0.005,
    actor_chunk_len: int = None,
):
    self.frozen_vla = frozen_vla
    self.encoder = encoder
    self.actor = actor
    self.critic = critic
    self.device = device
    self.max_batch_size = max_batch_size
    self.batch_timeout_s = batch_timeout_s
    # If actor uses shorter chunk than VLA, slice vla_actions accordingly
    self.actor_chunk_len = actor_chunk_len

    self._q: queue.Queue = queue.Queue()
    self._stop = threading.Event()
    self._thread = threading.Thread(target=self._loop, daemon=True)
    self.warmup_mode = False   # when True, return pure VLA actions (no actor)
infer
infer(img_pair: list, instruction: str, prop_state=None)

Called from env threads. Blocks until the batch result is ready.

Parameters:

  • img_pair (list, required) — [primary_img, wrist_img] (numpy arrays, already flipped)
  • instruction (str, required) — task language description
  • prop_state (default None) — np.ndarray or torch.Tensor of shape (prop_dim,), the proprioceptive state

Returns: (rl_token_cpu, vla_action_cpu, action_cpu, log_prob_float, value_float); tensors are (1, ...) on CPU.

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
def infer(self, img_pair: list, instruction: str,
          prop_state=None):
    """
    Called from env threads. Blocks until the batch result is ready.

    Args:
        img_pair: [primary_img, wrist_img]  (numpy arrays, already flipped)
        instruction: task language description
        prop_state: np.ndarray or torch.Tensor (prop_dim,) — proprioceptive state

    Returns:
        (rl_token_cpu, vla_action_cpu, action_cpu, log_prob_float, value_float)
        tensors are (1, ...) on CPU.
    """
    done = threading.Event()
    box: list = [None]
    self._q.put((img_pair, instruction, prop_state, done, box))
    done.wait()
    return box[0]
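The request-queue pattern itself is small enough to sketch independently of the VLA (self-contained; the "model" is a toy batch function, and the batch size / timeout values are illustrative):

```python
import queue
import threading
import time
from concurrent.futures import ThreadPoolExecutor

request_q = queue.Queue()
stop = threading.Event()

def batch_model(xs):                    # toy stand-in for one batched GPU forward
    return [x * 2 for x in xs]

def server_loop(max_batch=64, timeout_s=0.005):
    while not stop.is_set():
        try:
            first = request_q.get(timeout=0.05)
        except queue.Empty:
            continue
        # Drain more requests until the batch fills or the timeout expires
        batch = [first]
        deadline = time.monotonic() + timeout_s
        while len(batch) < max_batch and time.monotonic() < deadline:
            try:
                batch.append(request_q.get_nowait())
            except queue.Empty:
                time.sleep(0.0005)
        results = batch_model([x for x, _, _ in batch])
        for (_, done, box), r in zip(batch, results):
            box[0] = r
            done.set()                  # wake the waiting env thread

threading.Thread(target=server_loop, daemon=True).start()

def infer(x):                           # called from env threads; blocks on result
    done, box = threading.Event(), [None]
    request_q.put((x, done, box))
    done.wait()
    return box[0]

with ThreadPoolExecutor(4) as pool:     # concurrent callers get batched together
    outs = list(pool.map(infer, range(4)))
stop.set()
```

The real server follows the same shape but runs the VLA, encoder, actor, and critic in the batched forward and returns CPU tensors per request.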
ActionTokenStepRecord dataclass
ActionTokenStepRecord(rl_token: Tensor, vla_action: Tensor, action_taken: Tensor, old_log_prob: float, value: float = 0.0, prop_state: Optional[Tensor] = None, sub_tokens: list = list(), images: Optional[list] = None, instruction: Optional[str] = None)

One inference step during RLActionToken rollout.

pretrain_encoder_step
pretrain_encoder_step(frozen_vla, enc_dec: ActionTokenEncoderDecoder, batch_images: list, instructions: list, device: str = 'cuda')

One pretraining step: VLA forward → encoder → decoder → reconstruction loss.

Returns:

  • recon_loss — scalar tensor (with grad on enc_dec params only)

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
def pretrain_encoder_step(
    frozen_vla,
    enc_dec: ActionTokenEncoderDecoder,
    batch_images: list,
    instructions: list,
    device: str = "cuda",
):
    """
    One pretraining step: VLA forward → encoder → decoder → reconstruction loss.

    Returns:
        recon_loss: scalar tensor (with grad on enc_dec params only)
    """
    with torch.no_grad():
        action_queries = frozen_vla.get_action_queries(
            batch_images=batch_images,
            instructions=instructions,
        )  # (B, chunk_len, H) on device

    _, recon_loss = enc_dec(action_queries)
    return recon_loss
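The gradient boundary (frozen VLA vs. trainable encoder-decoder) can be verified on toy stand-ins; both `Linear` modules below are placeholders for the real networks:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
frozen_backbone = nn.Linear(8, 8)       # stand-in for the frozen VLA
enc_dec = nn.Linear(8, 8)               # stand-in for the encoder-decoder

with torch.no_grad():                   # VLA forward builds no graph
    action_queries = frozen_backbone(torch.randn(4, 8))

recon = enc_dec(action_queries)
loss = nn.functional.mse_loss(recon, action_queries.detach())
loss.backward()

# Only the encoder-decoder received gradients
assert enc_dec.weight.grad is not None
assert frozen_backbone.weight.grad is None
```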
collect_observations_fast
collect_observations_fast(suite_name: str, task_id: int, n_observations: int, steps_per_env: int = 20, num_envs: int = 4, n_initial_states: int = 50, libero_python: str = None, seed: int = 42) -> list

Lightweight observation collection for encoder pretraining.

Instead of running full episodes (300+ steps each, 99% CPU idle), just reset envs to diverse initial states and take a few random steps. Returns list of (images, instruction) tuples ready for VLA forward.

Much faster than collect_group: no VLA forward during collection, no action caching, just random exploration for observation diversity.

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
def collect_observations_fast(
    suite_name: str,
    task_id: int,
    n_observations: int,
    steps_per_env: int = 20,
    num_envs: int = 4,
    n_initial_states: int = 50,
    libero_python: str = None,
    seed: int = 42,
) -> list:
    """
    Lightweight observation collection for encoder pretraining.

    Instead of running full episodes (300+ steps each, 99% CPU idle),
    just reset envs to diverse initial states and take a few random steps.
    Returns list of (images, instruction) tuples ready for VLA forward.

    Much faster than collect_group: no VLA forward during collection,
    no action caching, just random exploration for observation diversity.
    """
    import threading
    from concurrent.futures import ThreadPoolExecutor, as_completed

    observations = []

    # Shared progress counter across threads
    obs_per_reset = 1 + steps_per_env
    n_resets = max(1, (n_observations + obs_per_reset - 1) // obs_per_reset)
    resets_per_env = (n_resets + num_envs - 1) // num_envs
    progress_lock = threading.Lock()
    progress = {"resets_done": 0, "obs_count": 0}
    log_interval = max(1, n_resets // 20)  # log every ~5%

    logger.info(f"  [collect_obs] Plan: {n_resets} env resets × {steps_per_env} steps/reset "
                f"= ~{n_resets * obs_per_reset} obs, using {num_envs} envs")

    def _collect_from_env(env_idx, states_to_visit):
        local_rng = np.random.RandomState(seed + env_idx * 10000)
        local_obs = []
        env = LiberoEnv(libero_python=libero_python)
        try:
            for s_idx in states_to_visit:
                obs = env.reset(
                    suite_name=suite_name,
                    task_id=task_id,
                    initial_state_idx=s_idx % n_initial_states,
                    seed=seed + env_idx * 1000 + s_idx,
                )
                task_desc = env.task_description
                local_obs.append((
                    [obs["primary_image"].copy(), obs["wrist_image"].copy()],
                    task_desc,
                ))
                for _ in range(steps_per_env):
                    random_action = local_rng.uniform(-1, 1, size=7).astype(np.float32)
                    random_action[6] = local_rng.choice([-1.0, 1.0])
                    obs, _, done = env.step(random_action)
                    local_obs.append((
                        [obs["primary_image"].copy(), obs["wrist_image"].copy()],
                        task_desc,
                    ))
                    if done:
                        break
                # Update shared progress
                with progress_lock:
                    progress["resets_done"] += 1
                    rd = progress["resets_done"]
                    if rd % log_interval == 0 or rd == n_resets:
                        est_obs = rd * obs_per_reset
                        logger.info(f"  [collect_obs] reset {rd}/{n_resets} "
                                    f"({rd * 100 // n_resets}%), ~{est_obs} obs collected")
        finally:
            env.close()
        return local_obs

    with ThreadPoolExecutor(max_workers=num_envs) as pool:
        futures = {}
        for e in range(num_envs):
            start_state = e * resets_per_env
            end_state = min(start_state + resets_per_env, n_resets)
            if start_state >= n_resets:
                break
            states = list(range(start_state, end_state))
            futures[pool.submit(_collect_from_env, e, states)] = e

        for fut in as_completed(futures):
            local_obs = fut.result()
            observations.extend(local_obs)

    # Trim to requested size
    if len(observations) > n_observations:
        np.random.shuffle(observations)
        observations = observations[:n_observations]

    logger.info(f"  [collect_obs] Collected {len(observations)} observations")
    return observations
extract_action_queries_from_obs
extract_action_queries_from_obs(frozen_vla, observations: list, batch_size: int = 16, device: str = 'cuda') -> torch.Tensor

Batch-extract action_queries from observations via frozen VLA.

Parameters:

  • observations (list, required) — list of (images, instruction) where images = [primary, wrist]

Returns:

  • all_queries (Tensor) — (N, chunk_len, H) tensor on device

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
def extract_action_queries_from_obs(
    frozen_vla,
    observations: list,
    batch_size: int = 16,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Batch-extract action_queries from observations via frozen VLA.

    Args:
        observations: list of (images, instruction) where images=[primary, wrist]

    Returns:
        all_queries: (N, chunk_len, H) tensor on device
    """
    N = len(observations)
    n_batches = (N + batch_size - 1) // batch_size
    all_queries = []
    for b_idx, start in enumerate(range(0, N, batch_size)):
        end = min(start + batch_size, N)
        batch_imgs = [observations[i][0] for i in range(start, end)]
        batch_instr = [observations[i][1] for i in range(start, end)]
        with torch.no_grad():
            aq = frozen_vla.get_action_queries(
                batch_images=batch_imgs,
                instructions=batch_instr,
            )  # (B, chunk_len, H)
        all_queries.append(aq)
        if (b_idx + 1) % max(1, n_batches // 10) == 0 or b_idx == n_batches - 1:
            logger.info(f"  [extract] batch {b_idx + 1}/{n_batches} "
                        f"({end}/{N} samples)")

    return torch.cat(all_queries, dim=0)  # (N, chunk_len, H)
extract_action_queries_dataset
extract_action_queries_dataset(frozen_vla, episodes: list, batch_size: int = 16, device: str = 'cuda') -> torch.Tensor

Batch-extract all action_queries from rollout episodes via frozen VLA. (Legacy: used when full episodes are already collected.)

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
def extract_action_queries_dataset(
    frozen_vla,
    episodes: list,
    batch_size: int = 16,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Batch-extract all action_queries from rollout episodes via frozen VLA.
    (Legacy: used when full episodes are already collected.)
    """
    obs_list = []
    for ep in episodes:
        for step in ep.step_records:
            obs_list.append(
                ([step.primary_image, step.wrist_image], step.instruction)
            )
    return extract_action_queries_from_obs(frozen_vla, obs_list, batch_size, device)
action_token_collect_group
action_token_collect_group(frozen_vla, encoder: ActionTokenEncoderDecoder, actor: ActionTokenActor, critic: ActionTokenCritic, suite_name: str, task_id: int, n_initial_states: int, action_norm_stats: dict, max_steps: int, chunk_len: int, G: int = 8, libero_python: Optional[str] = None, seed: int = 42, num_steps_wait: int = 10, device: str = 'cuda', video_dir: Optional[str] = None, num_envs: int = 4, group_idx: int = 0, batch_server: Optional[BatchInferenceServer] = None, store_images: bool = False, group_size: int = 1, reward_coef: float = 1.0) -> List[ActionTokenEpisode]

Collect G episodes using RLActionToken policy.

Uses BatchInferenceServer for GPU inference: all num_envs env threads submit requests concurrently; a single background thread batches them into one GPU forward pass. This maximizes GPU utilization vs. the old model_lock (batch=1).

Parameters:

Name Type Description Default
batch_server Optional[BatchInferenceServer]

shared server for this GPU (created by caller for cross-task batching). If None, creates a local server for this call only.

None
group_size int

number of trajectories per initial state. G episodes are split into G//group_size unique states, each repeated group_size times. Default 1 = legacy behavior (no repeat).

1
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
@torch.no_grad()
def action_token_collect_group(
    frozen_vla,
    encoder: ActionTokenEncoderDecoder,
    actor: ActionTokenActor,
    critic: ActionTokenCritic,
    suite_name: str,
    task_id: int,
    n_initial_states: int,
    action_norm_stats: dict,
    max_steps: int,
    chunk_len: int,
    G: int = 8,
    libero_python: Optional[str] = None,
    seed: int = 42,
    num_steps_wait: int = 10,
    device: str = "cuda",
    video_dir: Optional[str] = None,
    num_envs: int = 4,
    group_idx: int = 0,
    batch_server: Optional[BatchInferenceServer] = None,
    store_images: bool = False,
    group_size: int = 1,
    reward_coef: float = 1.0,
) -> List[ActionTokenEpisode]:
    """
    Collect G episodes using RLActionToken policy.

    Uses BatchInferenceServer for GPU inference: all num_envs env threads submit
    requests concurrently; a single background thread batches them into one GPU
    forward pass. This maximizes GPU utilization vs. the old model_lock (batch=1).

    Args:
        batch_server: shared server for this GPU (created by caller for cross-task
                      batching). If None, creates a local server for this call only.
        group_size: number of trajectories per initial state. G episodes are
                    split into G//group_size unique states, each repeated
                    group_size times. Default 1 = legacy behavior (no repeat).
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    frozen_vla.eval()
    encoder.eval()
    actor.eval()
    critic.eval()

    # Use caller-provided server (shared across tasks on same GPU) or create locally
    _own_server = batch_server is None
    if _own_server:
        batch_server = BatchInferenceServer(
            frozen_vla=frozen_vla,
            encoder=encoder,
            actor=actor,
            critic=critic,
            device=device,
        ).start()

    n_workers = min(G, num_envs)
    # Each episode gets its own env (LiberoEnv is not thread-safe for reuse)
    envs = []
    for _ei in range(G):
        envs.append(LiberoEnv(libero_python=libero_python))
        if (_ei + 1) % 10 == 0 or _ei == G - 1:
            print(f"  envs created: {_ei+1}/{G}", flush=True)

    # Assign initial states: same-state grouping.
    # G episodes → G//group_size unique states, each repeated group_size times
    num_unique = max(1, G // group_size)
    _rng = np.random.RandomState(seed + group_idx)
    unique_states = _rng.randint(0, n_initial_states, size=num_unique)
    state_ids = np.repeat(unique_states, group_size)[:G]  # [s0,s0,s0,s0, s1,s1, ...]

    episodes = [None] * G
    try:
        with ThreadPoolExecutor(max_workers=n_workers) as pool:
            futures = {}
            for g in range(G):
                fut = pool.submit(
                    _action_token_rollout_one,
                    env=envs[g],
                    batch_server=batch_server,
                    suite_name=suite_name,
                    task_id=task_id,
                    state_idx=int(state_ids[g]),
                    action_norm_stats=action_norm_stats,
                    max_steps=max_steps,
                    chunk_len=chunk_len,
                    num_steps_wait=num_steps_wait,
                    seed=seed + g,
                    record_video=(video_dir is not None),
                    episode_idx=g,
                    group_idx=group_idx,
                    video_dir=video_dir,
                    store_images=store_images,
                    reward_coef=reward_coef,
                )
                futures[fut] = g
            done_count = 0
            success_count = 0
            for fut in as_completed(futures):
                g_idx = futures[fut]
                ep = fut.result()
                episodes[g_idx] = ep
                done_count += 1
                if ep.success:
                    success_count += 1
                print(f"  [rollout][dev={device}] ep {done_count}/{G} done "
                      f"({'SUCCESS' if ep.success else 'fail'}, "
                      f"{ep.env_steps} steps) "
                      f"[{success_count}/{done_count} success so far]", flush=True)
    finally:
        for env in envs:
            env.close()
        if _own_server:
            batch_server.stop()

    return episodes
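The BatchInferenceServer pattern described above (many env threads submit requests; one background thread drains them into a single batched forward pass) can be sketched without any GPU code. Everything here is a toy stand-in — `TinyBatchServer` and `batch_fn` are hypothetical names, not the actual BatchInferenceServer API:

```python
import queue
import threading

class TinyBatchServer:
    """Sketch of the request-batching pattern: callers block on infer();
    one background thread drains the queue and runs a single batched call."""

    def __init__(self, batch_fn):
        self.batch_fn = batch_fn  # stands in for one batched GPU forward
        self.q = queue.Queue()
        self.running = True
        self.thread = threading.Thread(target=self._loop, daemon=True)
        self.thread.start()

    def _loop(self):
        while self.running:
            try:
                first = self.q.get(timeout=0.05)
            except queue.Empty:
                continue
            batch = [first]
            while not self.q.empty():  # drain whatever else is queued
                batch.append(self.q.get_nowait())
            outs = self.batch_fn([x for x, _ in batch])
            for (_, slot), out in zip(batch, outs):
                slot[1] = out
                slot[0].set()  # wake the waiting caller

    def infer(self, x):
        slot = [threading.Event(), None]
        self.q.put((x, slot))
        slot[0].wait()
        return slot[1]

    def stop(self):
        self.running = False
        self.thread.join()

server = TinyBatchServer(lambda xs: [x * 2 for x in xs])
result = server.infer(3)
server.stop()
```

Requests arriving while a batch is in flight simply queue up and form the next batch, which is why concurrent env threads amortize to one forward pass per step rather than `batch=1` calls under a lock.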
compute_action_token_gae
compute_action_token_gae(episode: ActionTokenEpisode, gamma: float = 0.99, gae_lambda: float = 0.95)

Compute GAE advantages and returns for a single RLActionToken episode.

Returns:

Name Type Description
advantages

list of floats (len = finish_step)

returns

list of floats (len = finish_step)

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
def compute_action_token_gae(
    episode: ActionTokenEpisode,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
):
    """
    Compute GAE advantages and returns for a single RLActionToken episode.

    Returns:
        advantages: list of floats (len = finish_step)
        returns: list of floats (len = finish_step)
    """
    steps = episode.step_records
    n = episode.finish_step
    if n == 0:
        return [], []

    values = [s.value for s in steps[:n]]
    # Terminal value = 0 (episode ended)
    advantages = [0.0] * n
    returns = [0.0] * n

    # Sparse reward at last step only
    rewards = [0.0] * n
    rewards[-1] = episode.reward

    gae = 0.0
    for t in reversed(range(n)):
        next_value = values[t + 1] if t + 1 < n else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        gae = delta + gamma * gae_lambda * gae
        advantages[t] = gae
        returns[t] = gae + values[t]

    return advantages, returns
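The backward recursion above can be checked on a toy episode. This is a minimal sketch with plain Python lists (no episode object assumed), mirroring the same delta/gae update with terminal value 0:

```python
def gae(rewards, values, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation with V(s_terminal) = 0."""
    n = len(rewards)
    advantages, returns = [0.0] * n, [0.0] * n
    g = 0.0
    for t in reversed(range(n)):
        next_v = values[t + 1] if t + 1 < n else 0.0  # terminal value = 0
        delta = rewards[t] + gamma * next_v - values[t]
        g = delta + gamma * lam * g
        advantages[t] = g
        returns[t] = g + values[t]
    return advantages, returns

# Sparse reward: only the final step pays out, as in these episodes.
adv, ret = gae(rewards=[0.0, 0.0, 1.0], values=[0.5, 0.6, 0.7])
```

At the last step the advantage is simply `reward - value` (1.0 − 0.7 = 0.3) and the return is exactly the reward; earlier steps pick up discounted, lambda-weighted shares of it.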
action_token_ppo_loss
action_token_ppo_loss(encoder: ActionTokenEncoderDecoder, actor: ActionTokenActor, critic: ActionTokenCritic, episodes: List[ActionTokenEpisode], gamma: float = 0.99, gae_lambda: float = 0.95, clip_eps: float = 0.2, vf_coef: float = 0.5, recon_loss_coef: float = 0.0, frozen_vla=None, device: str = 'cuda')

Compute PPO loss on a batch of RLActionToken episodes.

Only encoder + actor + critic have gradients. Optionally add reconstruction loss as regularizer.

Returns:

Name Type Description
loss

scalar tensor

stats

dict with training metrics

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
def action_token_ppo_loss(
    encoder: ActionTokenEncoderDecoder,
    actor: ActionTokenActor,
    critic: ActionTokenCritic,
    episodes: List[ActionTokenEpisode],
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
    clip_eps: float = 0.2,
    vf_coef: float = 0.5,
    recon_loss_coef: float = 0.0,
    frozen_vla=None,
    device: str = "cuda",
):
    """
    Compute PPO loss on a batch of RLActionToken episodes.

    Only encoder + actor + critic have gradients.
    Optionally add reconstruction loss as regularizer.

    Returns:
        loss: scalar tensor
        stats: dict with training metrics
    """
    all_rl_tokens = []
    all_vla_actions = []
    all_actions_taken = []
    all_old_log_probs = []
    all_advantages = []
    all_returns = []
    all_old_values = []
    all_prop_states = []

    for ep in episodes:
        adv, ret = compute_action_token_gae(ep, gamma, gae_lambda)
        for t in range(ep.finish_step):
            step = ep.step_records[t]
            all_rl_tokens.append(step.rl_token)
            all_vla_actions.append(step.vla_action)
            all_actions_taken.append(step.action_taken)
            all_old_log_probs.append(step.old_log_prob)
            all_advantages.append(adv[t])
            all_returns.append(ret[t])
            all_old_values.append(step.value)
            prop = step.prop_state if step.prop_state is not None else torch.zeros(8)
            all_prop_states.append(prop)

    if not all_rl_tokens:
        return torch.tensor(0.0, device=device, requires_grad=True), {"n_steps": 0}

    # Stack to batched tensors
    rl_tokens = torch.stack(all_rl_tokens).to(device)          # (N, 1, D)
    vla_actions = torch.stack(all_vla_actions).to(device)      # (N, C, A)
    actions_taken = torch.stack(all_actions_taken).to(device)  # (N, C, A)
    old_lp = torch.tensor(all_old_log_probs, device=device)   # (N,)
    advantages = torch.tensor(all_advantages, device=device)   # (N,)
    returns = torch.tensor(all_returns, device=device)         # (N,)
    old_values = torch.tensor(all_old_values, device=device)   # (N,)
    prop_states = torch.stack(all_prop_states).to(device)      # (N, prop_dim)

    # Normalize advantages
    if advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # New policy log probs
    new_lp = actor.log_prob_of(rl_tokens, vla_actions, actions_taken, prop_states)  # (N,)

    # PPO clipped policy loss
    ratio = torch.exp(new_lp - old_lp)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    pg_loss = -torch.min(surr1, surr2).mean()

    # Value loss (clipped)
    new_values = critic(rl_tokens)  # (N,)
    v_clipped = old_values + torch.clamp(new_values - old_values, -10.0, 10.0)
    vf_loss = torch.max(
        (new_values - returns) ** 2,
        (v_clipped - returns) ** 2,
    ).mean()

    loss = pg_loss + vf_coef * vf_loss

    # Optional reconstruction regularization
    recon_loss_val = 0.0
    if recon_loss_coef > 0.0 and rl_tokens.size(0) > 0:
        # Re-encode to get reconstruction loss (encoder must be in train mode)
        # Note: rl_tokens were computed with no_grad during rollout, so we need
        # to recompute through encoder for gradient flow
        # This requires action_queries which we don't have — skip if not available
        pass

    stats = {
        "pg_loss": pg_loss.item(),
        "vf_loss": vf_loss.item(),
        "loss": loss.item(),
        "ratio_mean": ratio.mean().item(),
        "ratio_max": ratio.max().item(),
        "ratio_min": ratio.min().item(),
        "clip_frac": ((ratio - 1.0).abs() > clip_eps).float().mean().item(),
        "n_steps": len(all_rl_tokens),
        "advantage_mean": advantages.mean().item(),
        "value_mean": new_values.mean().item(),
    }
    return loss, stats
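The clipped surrogate at the heart of this loss can be illustrated per sample without tensors. A sketch only — the real loss operates on batched tensors as in the source above:

```python
import math

def ppo_clip_term(new_lp, old_lp, advantage, clip_eps=0.2):
    """Per-sample PPO clipped objective: -min(ratio * A, clip(ratio) * A)."""
    ratio = math.exp(new_lp - old_lp)
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    return -min(ratio * advantage, clipped * advantage)

# A ratio of e^0.5 ≈ 1.65 with positive advantage is clipped at 1 + eps = 1.2:
# probability mass moved beyond the trust region earns no extra reward.
loss = ppo_clip_term(new_lp=0.5, old_lp=0.0, advantage=1.0)
```

The `clip_frac` statistic in the source counts exactly the samples where this clamp is active, i.e. `|ratio - 1| > clip_eps`.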
push_episodes_to_buffer
push_episodes_to_buffer(episodes: List[ActionTokenEpisode], replay_buffer: ReplayBuffer, gamma_per_step: float = 0.99)

Convert episode step_records into (s, a, r, s', done) transitions and push them into the replay buffer.

Chunk subsampling (paper, stride=2): for each executed chunk, we create transitions starting from positions [0, 2, 4, 6] within the chunk. A position-p transition has: - state: rl_token at obs[p] within this chunk - action: action_taken[p:C] || next_chunk.action_taken[0:p] (length-C cross-chunk slice) - next_state: rl_token at obs[p] within the NEXT chunk - done: True only for the terminal chunk

Reward (paper Eq. 3): Q̂ = Σ_{t'=0}^{C-1} γ^{t'} r_{t+t'} + γ^C Q_next. For sparse reward, only the terminal chunk has non-zero reward. At stride position p in terminal chunk (done at step d): reward = γ^(d-p) * episode_reward (discounted by steps from p to done)

Reward scheme: success=1.0, failure=0.0 (paper scheme)

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
def push_episodes_to_buffer(
    episodes: List[ActionTokenEpisode],
    replay_buffer: ReplayBuffer,
    gamma_per_step: float = 0.99,
):
    """
    Convert episode step_records into (s, a, r, s', done) transitions and push
    them into the replay buffer.

    Chunk subsampling (paper, stride=2):
      For each executed chunk, we create transitions starting from positions
      [0, 2, 4, 6] within the chunk. Position-p transition:
        - state:      rl_token at obs[p] within this chunk
        - action:     action_taken[p:C] || next_chunk.action_taken[0:p]  (length-C cross-chunk slice)
        - next_state: rl_token at obs[p] within the NEXT chunk
        - done:       True only for terminal chunk

    Reward (paper Eq. 3): Q̂ = Σ_{t'=0}^{C-1} γ^{t'} r_{t+t'} + γ^C Q_next.
    For sparse reward, only the terminal chunk has non-zero reward.
    At stride position p in terminal chunk (done at step d):
        reward = γ^(d-p) * episode_reward   (discounted by steps from p to done)

    Reward scheme: success=1.0, failure=0.0 (paper scheme)
    """
    n_pushed = 0
    for ep in episodes:
        steps = ep.step_records
        n = ep.finish_step
        if n == 0:
            continue

        # Infer chunk_len from first step record
        chunk_len = steps[0].action_taken.shape[0]
        stride = 2
        stride_positions = list(range(0, chunk_len, stride))  # [0, 2, 4, 6] for C=8

        # Terminal step within last chunk (0-based): done_cache_idx was set after +1
        done_step = max(0, ep.done_cache_idx - 1) if ep.done_cache_idx >= 0 else chunk_len - 1

        for t in range(n):
            s = steps[t]
            is_last = (t == n - 1)
            done = is_last
            s_next = steps[t + 1] if not is_last else None

            for pos_idx, p in enumerate(stride_positions):
                # ── Terminal chunk: skip stride positions beyond the done step ──
                if is_last and p > done_step:
                    break

                # ── Current state at position p ──
                if p == 0:
                    rl_tok = s.rl_token
                    vla_act = s.vla_action
                    prop = s.prop_state
                else:
                    sub_idx = pos_idx - 1  # sub_tokens index (0→pos2, 1→pos4, 2→pos6)
                    if sub_idx >= len(s.sub_tokens):
                        # Episode ended before position p within this chunk; skip
                        break
                    rl_tok, vla_act, prop = s.sub_tokens[sub_idx]

                # ── Action: cross-chunk slice of length C ──
                if p == 0:
                    action = s.action_taken  # full chunk, no slicing needed
                else:
                    tail = s.action_taken[p:]           # (C-p, A)
                    if s_next is not None:
                        head = s_next.action_taken[:p]  # (p, A)
                    else:
                        # Terminal chunk: pad with zeros
                        head = torch.zeros(p, s.action_taken.shape[-1],
                                           dtype=s.action_taken.dtype)
                    action = torch.cat([tail, head], dim=0)  # (C, A)

                # ── Reward (paper Eq. 3): discounted within-chunk reward ──
                if is_last and ep.reward != 0.0:
                    # Sparse reward at terminal step: γ^(done_step - p) * R
                    stride_reward = (gamma_per_step ** (done_step - p)) * ep.reward
                else:
                    stride_reward = 0.0

                # ── Next state at position p ──
                if is_last:
                    next_rl_tok = torch.zeros_like(rl_tok)
                    next_vla_act = torch.zeros_like(vla_act)
                    next_prop = torch.zeros_like(prop) if prop is not None else None
                elif p == 0:
                    next_rl_tok = s_next.rl_token
                    next_vla_act = s_next.vla_action
                    next_prop = s_next.prop_state
                else:
                    sub_idx = pos_idx - 1
                    if sub_idx >= len(s_next.sub_tokens):
                        # Next chunk doesn't have sub_token at position p (ended early)
                        # Sub-positions p+2, p+4 also missing → stop this chunk's subsampling
                        break
                    next_rl_tok, next_vla_act, next_prop = s_next.sub_tokens[sub_idx]

                replay_buffer.push(
                    rl_token=rl_tok,
                    vla_action=vla_act,
                    action_taken=action,
                    reward=stride_reward,
                    next_rl_token=next_rl_tok,
                    next_vla_action=next_vla_act,
                    done=done,
                    task_id=ep.task_id,
                    prop_state=prop,
                    next_prop_state=next_prop,
                )
                n_pushed += 1
    return n_pushed
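The terminal-chunk reward rule above (`reward = γ^(done_step - p) * episode_reward` for each stride position `p ≤ done_step`) can be computed in isolation. A sketch of just that arithmetic, with defaults matching the source (chunk_len=8, stride=2):

```python
def terminal_stride_rewards(done_step, episode_reward,
                            chunk_len=8, stride=2, gamma=0.99):
    """Discounted reward for each stride position p <= done_step
    in the terminal chunk (positions beyond done_step are skipped)."""
    return {
        p: (gamma ** (done_step - p)) * episode_reward
        for p in range(0, chunk_len, stride)
        if p <= done_step
    }

# Episode succeeds at within-chunk step 5: positions 0, 2, 4 get transitions;
# position 6 lies past the done step and is dropped.
rewards = terminal_stride_rewards(done_step=5, episode_reward=1.0)
```

Closer stride positions receive less-discounted reward, so the replay buffer's terminal transitions stay consistent with the per-step discount `gamma_per_step`.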
vla_finetune_step
vla_finetune_step(vla, encoder: ActionTokenEncoderDecoder, actor: ActionTokenActor, q_critic: ActionTokenQCritic, episodes: List[ActionTokenEpisode], beta: float = 0.1, device: str = 'cuda', micro_batch: int = 4)

Full fine-tune: re-run VLA forward on stored images with gradients enabled.

Gradient path: actor_loss → actor → rl_token → encoder → action_queries → VLA. The critic is frozen during this step (it only provides the Q signal; its parameters are not updated).

Parameters:

Name Type Description Default
episodes List[ActionTokenEpisode]

current iteration's episodes (must have .images stored)

required
micro_batch int

VLA forward batch size (controls GPU memory)

4
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
def vla_finetune_step(
    vla,
    encoder: ActionTokenEncoderDecoder,
    actor: ActionTokenActor,
    q_critic: ActionTokenQCritic,
    episodes: List[ActionTokenEpisode],
    beta: float = 0.1,
    device: str = "cuda",
    micro_batch: int = 4,
):
    """
    Full fine-tune: re-run VLA forward on stored images with gradients enabled.

    Gradient path: actor_loss → actor → rl_token → encoder → action_queries → VLA
    The critic is frozen during this step (only provides Q signal, no param update).

    Args:
        episodes: current iteration's episodes (must have .images stored)
        micro_batch: VLA forward batch size (controls GPU memory)
    """
    # Collect all (images, instruction, prop) from step records
    all_imgs, all_instrs, all_props = [], [], []
    for ep in episodes:
        for sr in ep.step_records:
            if sr.images is not None:
                all_imgs.append(sr.images)
                all_instrs.append(sr.instruction)
                all_props.append(sr.prop_state)

    if not all_imgs:
        # No stored images to fine-tune on; match the normal return shape.
        return {"vla_loss": 0.0, "vla_q": 0.0, "vla_bc": 0.0, "vla_n_samples": 0}

    # Freeze critic params (we only want gradients for VLA/encoder/actor)
    critic_was_training = q_critic.training
    for p in q_critic.parameters():
        p.requires_grad_(False)

    total_loss = 0.0
    total_q = 0.0
    total_bc = 0.0
    n_batches = 0

    for i in range(0, len(all_imgs), micro_batch):
        batch_imgs = all_imgs[i:i + micro_batch]
        batch_instr = all_instrs[i:i + micro_batch]
        batch_props = all_props[i:i + micro_batch]
        B = len(batch_imgs)

        # VLA forward WITH gradients (the whole point of full fine-tune)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            action_queries, vla_actions = vla.get_vla_action(
                batch_images=batch_imgs, instructions=batch_instr)
        # action_queries: (B, chunk_len, H) with grad through VLA

        rl_tokens = encoder.encode(action_queries)  # (B, 1, D) with grad

        props_t = torch.stack([p.float() for p in batch_props]).to(device) if batch_props[0] is not None else None
        actions, _ = actor(rl_tokens, vla_actions, props_t, deterministic=False)

        # Q provides gradient signal to actions and rl_tokens (but critic params frozen)
        q_val = q_critic.q1_forward(rl_tokens, actions, props_t)
        bc_penalty = ((actions - vla_actions) ** 2).sum(dim=(-2, -1)).mean()
        loss = -q_val.mean() + beta * bc_penalty
        loss.backward()

        total_loss += loss.item()
        total_q += q_val.mean().item()
        total_bc += bc_penalty.item()
        n_batches += 1

    # Restore critic
    for p in q_critic.parameters():
        p.requires_grad_(True)
    if critic_was_training:
        q_critic.train()

    n = max(n_batches, 1)
    stats = {
        "vla_loss": total_loss / n,
        "vla_q": total_q / n,
        "vla_bc": total_bc / n,
        "vla_n_samples": len(all_imgs),
    }
    return stats
action_token_td_critic_update
action_token_td_critic_update(actor: ActionTokenActor, q_critic: ActionTokenQCritic, target_q_critic: ActionTokenQCritic, replay_buffer: ReplayBuffer, batch_size: int = 256, gamma: float = 0.99, device: str = 'cuda', target_noise_std: float = 0.2, target_noise_clip: float = 0.5, n_tasks: int = 0, target_actor: ActionTokenActor = None)

TD3-style twin-Q critic update from replay buffer (Eq. 3 in paper).

Target: Q̂ = Σ γ^t' r_t' + γ^C * min(Q1', Q2')(s', a') where a' ~ π_target(·|s', ã') + clipped noise (target policy smoothing)

Parameters:

Name Type Description Default
n_tasks int

if > 0, use per-task stratified sampling for balanced multi-task update.

0
target_actor ActionTokenActor

Polyak-averaged actor for computing next actions (TD3). Falls back to online actor if None.

None

Returns:

Name Type Description
critic_loss

scalar tensor (with grad on q_critic params)

stats

dict

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
def action_token_td_critic_update(
    actor: ActionTokenActor,
    q_critic: ActionTokenQCritic,
    target_q_critic: ActionTokenQCritic,
    replay_buffer: ReplayBuffer,
    batch_size: int = 256,
    gamma: float = 0.99,
    device: str = "cuda",
    target_noise_std: float = 0.2,
    target_noise_clip: float = 0.5,
    n_tasks: int = 0,
    target_actor: ActionTokenActor = None,
):
    """
    TD3-style twin-Q critic update from replay buffer (Eq. 3 in paper).

    Target: Q̂ = Σ γ^t' r_t' + γ^C * min(Q1', Q2')(s', a')
    where a' ~ π_target(·|s', ã') + clipped noise  (target policy smoothing)

    Args:
        n_tasks: if > 0, use per-task stratified sampling for balanced multi-task update.
        target_actor: Polyak-averaged actor for computing next actions (TD3).
                      Falls back to online actor if None.

    Returns:
        critic_loss: scalar tensor (with grad on q_critic params)
        stats: dict
    """
    if n_tasks > 0:
        rl_tok, vla_act, act_taken, rew, next_rl_tok, next_vla_act, done, prop, next_prop = \
            replay_buffer.sample_balanced(batch_size, n_tasks=n_tasks, device=device)
    else:
        rl_tok, vla_act, act_taken, rew, next_rl_tok, next_vla_act, done, prop, next_prop = \
            replay_buffer.sample(batch_size, device=device)

    # ── Target Q value (TD3: target policy smoothing + min of twin Q) ──
    with torch.no_grad():
        # Next action from target actor (TD3) + smoothing noise
        _actor_for_target = target_actor if target_actor is not None else actor
        next_action, _ = _actor_for_target(next_rl_tok, next_vla_act, next_prop, deterministic=True)
        # TD3 target policy smoothing: add clipped noise to target actions
        noise = torch.randn_like(next_action) * target_noise_std
        noise = noise.clamp(-target_noise_clip, target_noise_clip)
        next_action = (next_action + noise).clamp(-1.0, 1.0)

        # Target Q with min of twin Q (paper: Eq.3)
        tq1, tq2 = target_q_critic(next_rl_tok, next_action, next_prop)
        next_q = torch.min(tq1, tq2)  # (B,)
        target = rew + gamma * next_q * (1.0 - done)
        # Clip the bootstrap target from above to prevent runaway positive Q
        # values. Proxy for the theoretical bound max_reward / (1 - gamma):
        # twice the largest |reward| in the batch, floored at 1.0.
        q_upper = max(1.0, rew.abs().max().item() * 2) if rew.numel() > 0 else 1.0
        target = target.clamp(max=q_upper)

    # ── Online Q loss ──
    q1, q2 = q_critic(rl_tok, act_taken, prop)
    critic_loss = F.mse_loss(q1, target) + F.mse_loss(q2, target)

    stats = {
        "critic_loss": critic_loss.item(),
        "q1_mean": q1.mean().item(),
        "q2_mean": q2.mean().item(),
        "target_mean": target.mean().item(),
    }
    return critic_loss, stats
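Per sample, the bootstrap target reduces to a few scalar operations. A sketch of the twin-Q minimum and upper clamp (smoothing noise omitted, since it only perturbs the next action):

```python
def td3_target(reward, done, q1_next, q2_next, gamma=0.99, q_upper=1.0):
    """TD3 bootstrap target: r + gamma * min(Q1', Q2') * (1 - done), clamped."""
    next_q = min(q1_next, q2_next)      # twin-Q pessimism vs. overestimation
    target = reward + gamma * next_q * (1.0 - done)
    return min(target, q_upper)         # cap runaway positive targets

# Non-terminal transition with sparse (zero) reward bootstraps from the
# smaller of the two target-Q estimates.
t = td3_target(reward=0.0, done=0.0, q1_next=0.8, q2_next=0.9)
```

For a terminal transition (`done=1`) the bootstrap term vanishes and the target is just the (clamped) reward.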
action_token_td_actor_update
action_token_td_actor_update(actor: ActionTokenActor, q_critic: ActionTokenQCritic, replay_buffer: ReplayBuffer, batch_size: int = 256, beta: float = 1.0, device: str = 'cuda', n_tasks: int = 0)

DDPG-style actor update from the RL Token paper (Eq. 5):

L_π(θ) = E[ -Q_ψ(x, a) + β ‖a - ã‖² ]

where a ~ π_θ (stochastic rsample, paper Eq. 5). With fixed small std=0.1, the gradient direction is nearly identical to the deterministic mean, but we match the paper's formulation exactly.

Parameters:

Name Type Description Default
n_tasks int

if > 0, use per-task stratified sampling for balanced multi-task update.

0

Returns:

Name Type Description
actor_loss

scalar tensor (with grad on actor params)

stats

dict

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
def action_token_td_actor_update(
    actor: ActionTokenActor,
    q_critic: ActionTokenQCritic,
    replay_buffer: ReplayBuffer,
    batch_size: int = 256,
    beta: float = 1.0,
    device: str = "cuda",
    n_tasks: int = 0,
):
    """
    DDPG-style actor update from the RL Token paper (Eq. 5):

    L_π(θ) = E[ -Q_ψ(x, a) + β ‖a - ã‖² ]

    where a ~ π_θ (stochastic rsample, paper Eq. 5). With fixed small std=0.1,
    the gradient direction is nearly identical to the deterministic mean, but we
    match the paper's formulation exactly.

    Args:
        n_tasks: if > 0, use per-task stratified sampling for balanced multi-task update.

    Returns:
        actor_loss: scalar tensor (with grad on actor params)
        stats: dict
    """
    if n_tasks > 0:
        rl_tok, vla_act, _, _, _, _, _, prop, _ = \
            replay_buffer.sample_balanced(batch_size, n_tasks=n_tasks, device=device)
    else:
        rl_tok, vla_act, _, _, _, _, _, prop, _ = \
            replay_buffer.sample(batch_size, device=device)

    # Paper Eq. 5: a ~ π_θ (stochastic rsample for correct gradient)
    action, _ = actor(rl_tok, vla_act, prop, deterministic=False)

    # Q-value of the sampled action (only Q1 for efficiency, as in TD3)
    q_val = q_critic.q1_forward(rl_tok, action, prop)  # (B,)

    # BC regularization: ‖a - ã‖² (anchor to VLA reference)
    bc_penalty = ((action - vla_act) ** 2).sum(dim=(-2, -1)).mean()  # scalar

    # Paper Eq. 5: minimize -Q + β * BC
    actor_loss = -q_val.mean() + beta * bc_penalty

    stats = {
        "actor_loss": actor_loss.item(),
        "q_actor_mean": q_val.mean().item(),
        "bc_penalty": bc_penalty.item(),
    }
    return actor_loss, stats
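For a single sample, Eq. 5 is just a Q term plus a squared anchor to the VLA reference action. A scalar sketch with plain lists standing in for the (C, A) action tensors:

```python
def ddpg_bc_actor_loss(q_value, action, vla_action, beta=1.0):
    """Paper Eq. 5 for one sample: -Q(x, a) + beta * ||a - a_ref||^2."""
    bc_penalty = sum((a - r) ** 2 for a, r in zip(action, vla_action))
    return -q_value + beta * bc_penalty

# Small deviation from the VLA action costs little; the -Q term dominates.
loss = ddpg_bc_actor_loss(q_value=0.5, action=[0.1, 0.2],
                          vla_action=[0.0, 0.0], beta=1.0)
```

Raising `beta` tightens the policy to the VLA prior; lowering it lets the critic pull actions further from the reference.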
action_token_td_update
action_token_td_update(actor: ActionTokenActor, critic, replay_buffer: ReplayBuffer, batch_size: int = 256, gamma: float = 0.99, device: str = 'cuda', target_critic=None, beta: float = 1.0, update_actor: bool = True, target_noise_std: float = 0.2, target_noise_clip: float = 0.5)

Combined TD3-style update step (backward compat wrapper).

If critic is ActionTokenQCritic: uses the new TD3/DDPG-style from paper. If critic is ActionTokenCritic (legacy V(s)): falls back to old logic.

Parameters:

Name Type Description Default
beta float

BC regularization coefficient (paper Eq. 5)

1.0
update_actor bool

If False, only update critic (TD3 delayed actor update)

True
target_noise_std float

Std of noise added to target policy actions

0.2
target_noise_clip float

Clip range for target policy noise

0.5

Returns:

Name Type Description
loss

scalar tensor

stats

dict

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_trainer.py
def action_token_td_update(
    actor: ActionTokenActor,
    critic,  # ActionTokenQCritic or legacy ActionTokenCritic
    replay_buffer: ReplayBuffer,
    batch_size: int = 256,
    gamma: float = 0.99,
    device: str = "cuda",
    target_critic=None,
    beta: float = 1.0,
    update_actor: bool = True,
    target_noise_std: float = 0.2,
    target_noise_clip: float = 0.5,
):
    """
    Combined TD3-style update step (backward compat wrapper).

    If critic is ActionTokenQCritic: uses the new TD3/DDPG-style from paper.
    If critic is ActionTokenCritic (legacy V(s)): falls back to old logic.

    Args:
        beta: BC regularization coefficient (paper Eq. 5)
        update_actor: If False, only update critic (TD3 delayed actor update)
        target_noise_std: Std of noise added to target policy actions
        target_noise_clip: Clip range for target policy noise

    Returns:
        loss: scalar tensor
        stats: dict
    """
    if isinstance(critic, ActionTokenQCritic):
        # ── New TD3-style from paper ──
        # Critic update
        critic_loss, c_stats = action_token_td_critic_update(
            actor=actor,
            q_critic=critic,
            target_q_critic=target_critic,
            replay_buffer=replay_buffer,
            batch_size=batch_size,
            gamma=gamma,
            device=device,
            target_noise_std=target_noise_std,
            target_noise_clip=target_noise_clip,
        )

        if update_actor:
            # Actor update (DDPG + BC regularization)
            actor_loss, a_stats = action_token_td_actor_update(
                actor=actor,
                q_critic=critic,
                replay_buffer=replay_buffer,
                batch_size=batch_size,
                beta=beta,
                device=device,
            )
            loss = critic_loss + actor_loss
            stats = {**c_stats, **a_stats, "td_loss": loss.item()}
        else:
            loss = critic_loss
            stats = {**c_stats, "actor_loss": 0.0, "td_loss": loss.item()}

        return loss, stats

    else:
        # ── Legacy V(s) path (backward compat for PPO-based code) ──
        rl_tok, vla_act, act_taken, rew, next_rl_tok, next_vla_act, done, prop, next_prop = \
            replay_buffer.sample(batch_size, device=device)

        with torch.no_grad():
            value_net = target_critic if target_critic is not None else critic
            next_value = value_net(next_rl_tok)  # (B,)
            target = rew + gamma * next_value * (1.0 - done)

        value = critic(rl_tok)
        critic_loss = F.mse_loss(value, target)

        advantage = (target - value).detach()
        if advantage.numel() > 1:
            advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

        log_prob = actor.log_prob_of(
            rl_tok.unsqueeze(1) if rl_tok.dim() == 2 else rl_tok,
            vla_act,
            act_taken,
            prop if hasattr(actor, 'prop_dim') and actor.prop_dim > 0 else None,
        )
        actor_loss = -(advantage * log_prob).mean()

        loss = actor_loss + 0.5 * critic_loss

        stats = {
            "actor_loss": actor_loss.item(),
            "critic_loss": critic_loss.item(),
            "td_loss": loss.item(),
            "value_mean": value.mean().item(),
            "advantage_mean": advantage.mean().item(),
            "target_mean": target.mean().item(),
        }
        return loss, stats

Fast rollout

action_token_rollout_fast

Fast ActionToken rollout — step-lock architecture.

All envs move in lockstep:
  1. Batch VLA forward for ALL active envs (one GPU call)
  2. Batch encoder + actor
  3. ALL envs execute chunk in parallel threads
  4. Collect results, repeat

No BatchInferenceServer needed. No async queuing. No batch fragmentation.

Speedup: ~50x vs original (env creation + batch fragmentation eliminated).
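The four-step loop can be condensed into a skeleton; this is an illustrative sketch of the step-lock pattern (the `policy_batch_fn`/`step_chunk_fn` stubs are placeholders, not the repo's classes):

```python
from concurrent.futures import ThreadPoolExecutor

def steplock_rollout(envs, policy_batch_fn, step_chunk_fn, max_chunks):
    """Skeleton of the step-lock pattern: one batched GPU call for all
    active envs, then every active env executes its chunk in a thread."""
    active = [True] * len(envs)
    obs = [env.reset() for env in envs]
    for _ in range(max_chunks):
        ids = [i for i, a in enumerate(active) if a]
        if not ids:
            break
        # Steps 1-2: ONE batched forward over all active envs.
        chunks = policy_batch_fn([obs[i] for i in ids])
        # Step 3: all active envs execute their chunk concurrently.
        with ThreadPoolExecutor(max_workers=len(ids)) as pool:
            futs = {pool.submit(step_chunk_fn, envs[i], c): i
                    for i, c in zip(ids, chunks)}
            # Step 4: collect results; finished envs drop out of the batch.
            for f in futs:
                new_obs, done = f.result()
                obs[futs[f]] = new_obs
                if done:
                    active[futs[f]] = False
    return obs
```

The key property is that the GPU batch shrinks as envs finish, but there is never more than one forward pass in flight, so no batching server or async queue is needed.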

action_token_collect_group_steplock
action_token_collect_group_steplock(env_pool: PersistentEnvPool, frozen_vla, encoder: ActionTokenEncoderDecoder, actor: ActionTokenActor, critic: ActionTokenCritic, suite_name: str, task_id: int, n_initial_states: int, action_norm_stats: dict, max_steps: int, chunk_len: int, G: int = 64, seed: int = 42, num_steps_wait: int = 10, device: str = 'cuda', video_dir: Optional[str] = None, group_idx: int = 0, store_images: bool = False, group_size: int = 1, reward_coef: float = 1.0, actor_chunk_len: int = None, env_offset: int = 0, warmup_mode: bool = False) -> List[ActionTokenEpisode]

Collect G episodes using step-lock architecture.

All envs move in lockstep:
  1. Batch VLA forward (one GPU call for ALL active envs)
  2. Batch encoder + actor (or skip actor if warmup_mode)
  3. All envs execute chunk in parallel
  4. Repeat

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `env_offset` | `int` | starting env index in the pool | `0` |
| `actor_chunk_len` | `int` | if set, actor outputs shorter chunk than VLA | `None` |
| `warmup_mode` | `bool` | if True, use VLA actions directly (skip actor). For buffer pre-fill. | `False` |
Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_rollout_fast.py
@torch.no_grad()
def action_token_collect_group_steplock(
    env_pool: PersistentEnvPool,
    frozen_vla,
    encoder: ActionTokenEncoderDecoder,
    actor: ActionTokenActor,
    critic: ActionTokenCritic,
    suite_name: str,
    task_id: int,
    n_initial_states: int,
    action_norm_stats: dict,
    max_steps: int,
    chunk_len: int,
    G: int = 64,
    seed: int = 42,
    num_steps_wait: int = 10,
    device: str = "cuda",
    video_dir: Optional[str] = None,
    group_idx: int = 0,
    store_images: bool = False,
    group_size: int = 1,
    reward_coef: float = 1.0,
    actor_chunk_len: int = None,
    env_offset: int = 0,
    warmup_mode: bool = False,
) -> List[ActionTokenEpisode]:
    """
    Collect G episodes using step-lock architecture.

    All envs move in lockstep:
      1. Batch VLA forward (one GPU call for ALL active envs)
      2. Batch encoder + actor (or skip actor if warmup_mode)
      3. All envs execute chunk in parallel
      4. Repeat

    Args:
        env_offset: starting env index in the pool
        actor_chunk_len: if set, actor outputs shorter chunk than VLA
        warmup_mode: if True, use VLA actions directly (skip actor). For buffer pre-fill.
    """
    if actor_chunk_len is None:
        actor_chunk_len = chunk_len
    exec_chunk_len = actor_chunk_len  # how many steps to execute per chunk

    frozen_vla.eval()
    encoder.eval()
    actor.eval()

    # Assign initial states (same-state grouping)
    num_unique = max(1, G // group_size)
    _rng = np.random.RandomState(seed + group_idx)
    unique_states = _rng.randint(0, n_initial_states, size=num_unique)
    state_ids = np.repeat(unique_states, group_size)[:G]

    n_workers = min(G, len(env_pool))

    # ── Phase 1: Reset all envs in parallel ──
    from concurrent.futures import as_completed as _as_completed
    obs_list = [None] * G
    with ThreadPoolExecutor(max_workers=G) as _pool:
        _futs = {_pool.submit(env_pool.reset_env, env_offset + g, suite_name, task_id, int(state_ids[g]), seed + g): g for g in range(G)}
        for _f in _as_completed(_futs):
            obs_list[_futs[_f]] = _f.result()
    print(f"  reset done: {G} envs (parallel)", flush=True)

    task_descriptions = [env_pool.envs[env_offset + g].task_description for g in range(G)]

    # ── Phase 2: Warmup dummy steps (parallel) ──
    if num_steps_wait > 0:
        with ThreadPoolExecutor(max_workers=G) as _pool:
            _futs = {_pool.submit(_env_dummy_steps, env_pool, env_offset + g, num_steps_wait): g for g in range(G)}
            for _f in _as_completed(_futs):
                obs_list[_futs[_f]] = _f.result()

    # ── Phase 3: Step-lock main loop ──
    episodes = [ActionTokenEpisode(task_id=task_id, state_idx=int(state_ids[g])) for g in range(G)]
    active = [True] * G  # which envs are still running
    env_steps = [0] * G
    all_frames = [[] for _ in range(G)]  # video frames

    max_chunks = max_steps // exec_chunk_len + 1

    # Timing accumulators
    _t_vla_forward = 0.0
    _t_encoder_actor = 0.0
    _t_store_records = 0.0
    _t_unnormalize = 0.0
    _t_env_step = 0.0
    _t_total_chunks = 0
    _t_rollout_start = time.time()

    for chunk_idx in range(max_chunks):
        active_ids = [g for g in range(G) if active[g]]
        if not active_ids:
            break

        _t_chunk_start = time.time()

        # ── Step 1: Batch VLA forward for all active envs ──
        _t0 = time.time()
        batch_images = [[obs_list[g]["primary_image"], obs_list[g]["wrist_image"]] for g in active_ids]
        batch_instrs = [task_descriptions[g] for g in active_ids]
        batch_props = [np.array(obs_list[g]["state"], dtype=np.float32) for g in active_ids]

        print(f"  [VLA forward] batch={len(batch_images)}, active_envs={len(active_ids)}", flush=True)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            action_queries, vla_actions = frozen_vla.get_vla_action(
                batch_images=batch_images, instructions=batch_instrs)
        torch.cuda.synchronize()
        _t1 = time.time()
        _t_vla_forward += _t1 - _t0
        print(f"  [VLA done] aq={action_queries.shape} va={vla_actions.shape} time={_t1-_t0:.3f}s", flush=True)

        # ── Step 2: Batch encoder + actor ──
        _t0 = time.time()
        rl_tokens = encoder.encode(action_queries)  # (N_active, 1, D)

        props_t = torch.tensor(np.array(batch_props), dtype=torch.float32).to(device)

        # Slice VLA actions for actor if actor uses shorter chunk
        if actor_chunk_len < vla_actions.size(1):
            vla_actions_for_actor = vla_actions[:, :actor_chunk_len, :]
        else:
            vla_actions_for_actor = vla_actions

        if warmup_mode:
            # Warmup: use VLA actions directly, skip actor (like BatchInferenceServer.warmup_mode)
            actions_t = vla_actions_for_actor
            log_probs = torch.zeros(len(active_ids), device=device)
            values = torch.zeros(len(active_ids), device=device)
        else:
            actions_t, log_probs = actor(rl_tokens, vla_actions_for_actor, props_t, deterministic=False)
            values = critic(rl_tokens)  # (N_active,)
        torch.cuda.synchronize()
        _t1 = time.time()
        _t_encoder_actor += _t1 - _t0

        # Convert to numpy
        actions_np = actions_t.cpu().numpy()  # (N_active, exec_chunk_len, action_dim)
        vla_actions_cpu = vla_actions_for_actor.cpu()

        # ── Store step records ──
        _t0 = time.time()
        for i, g in enumerate(active_ids):
            sr = ActionTokenStepRecord(
                rl_token=rl_tokens[i:i+1].cpu().squeeze(0),
                vla_action=vla_actions_cpu[i],
                action_taken=actions_t[i].detach().cpu(),
                old_log_prob=log_probs[i].item() if log_probs is not None else 0.0,
                value=values[i].item(),
                prop_state=torch.tensor(batch_props[i]),
                images=[obs_list[g]["primary_image"].copy(), obs_list[g]["wrist_image"].copy()] if store_images else None,
                instruction=task_descriptions[g] if store_images else None,
            )
            episodes[g].step_records.append(sr)
        _t1 = time.time()
        _t_store_records += _t1 - _t0

        # ── Step 3: Unnormalize actions ──
        _t0 = time.time()
        action_chunks_unnorm = []
        for i in range(len(active_ids)):
            action_chunks_unnorm.append(_unnormalize(actions_np[i], action_norm_stats))
        _t1 = time.time()
        _t_unnormalize += _t1 - _t0

        # ── Step 4: All envs execute chunk in parallel ──
        _t0 = time.time()
        record_video = video_dir is not None
        with ThreadPoolExecutor(max_workers=len(active_ids)) as _pool:
            _futs = {}
            for i, g in enumerate(active_ids):
                _futs[_pool.submit(
                    _env_step_chunk, env_pool, env_offset + g, action_chunks_unnorm[i],
                    exec_chunk_len, record_video
                )] = (i, g)
            for _f in _as_completed(_futs):
                i, g = _futs[_f]
                obs, reward, done, steps_taken, frames = _f.result()
                obs_list[g] = obs
                env_steps[g] += steps_taken
                if record_video:
                    all_frames[g].extend(frames)
                if done or env_steps[g] >= max_steps:
                    active[g] = False
                    ep = episodes[g]
                    ep.success = bool(done and reward > 0.5)
                    ep.reward = reward_coef if ep.success else 0.0
                    ep.done_cache_idx = steps_taken
                    ep.finish_step = len(ep.step_records)
                    ep.env_steps = env_steps[g]
        _t1 = time.time()
        _t_env_step += _t1 - _t0
        _t_total_chunks += 1

        print(
            f"[TIMING] chunk {chunk_idx}: active={len(active_ids)} | "
            f"vla={_t_vla_forward/_t_total_chunks:.3f}s  enc+act={_t_encoder_actor/_t_total_chunks:.3f}s  "
            f"store={_t_store_records/_t_total_chunks:.3f}s  unnorm={_t_unnormalize/_t_total_chunks:.3f}s  "
            f"env_step={_t_env_step/_t_total_chunks:.3f}s  "
            f"chunk_total={time.time()-_t_chunk_start:.3f}s"
        )

    # ── Timing summary ──
    _t_rollout_total = time.time() - _t_rollout_start
    if _t_total_chunks > 0:
        print(
            f"\n[TIMING SUMMARY] rollout group {group_idx} | G={G} | {_t_total_chunks} chunks | total={_t_rollout_total:.2f}s\n"
            f"  vla_forward:    {_t_vla_forward:.2f}s ({100*_t_vla_forward/_t_rollout_total:.1f}%)  avg={_t_vla_forward/_t_total_chunks:.3f}s/chunk\n"
            f"  encoder+actor:  {_t_encoder_actor:.2f}s ({100*_t_encoder_actor/_t_rollout_total:.1f}%)  avg={_t_encoder_actor/_t_total_chunks:.3f}s/chunk\n"
            f"  store_records:  {_t_store_records:.2f}s ({100*_t_store_records/_t_rollout_total:.1f}%)  avg={_t_store_records/_t_total_chunks:.3f}s/chunk\n"
            f"  unnormalize:    {_t_unnormalize:.2f}s ({100*_t_unnormalize/_t_rollout_total:.1f}%)  avg={_t_unnormalize/_t_total_chunks:.3f}s/chunk\n"
            f"  env_step:       {_t_env_step:.2f}s ({100*_t_env_step/_t_rollout_total:.1f}%)  avg={_t_env_step/_t_total_chunks:.3f}s/chunk\n"
            f"  other/overhead: {_t_rollout_total - _t_vla_forward - _t_encoder_actor - _t_store_records - _t_unnormalize - _t_env_step:.2f}s"
        )

    # ── Finalize episodes ──
    for g in range(G):
        ep = episodes[g]
        if ep.finish_step == 0:  # timeout, never set
            ep.finish_step = len(ep.step_records)
            ep.env_steps = env_steps[g]
            ep.reward = 0.0

        if all_frames[g] and video_dir is not None:
            from AlphaBrain.training.reinforcement_learning.common.rollout import _save_video
            os.makedirs(video_dir, exist_ok=True)
            status = "success" if ep.success else "fail"
            vpath = os.path.join(video_dir,
                                 f"g{group_idx:04d}_e{g:02d}_t{task_id}_s{int(state_ids[g]):02d}_{status}.mp4")
            ep.video_path = _save_video(all_frames[g], vpath)

    return episodes
action_token_collect_multitask_steplock
action_token_collect_multitask_steplock(env_pool: PersistentEnvPool, frozen_vla, encoder: ActionTokenEncoderDecoder, actor: ActionTokenActor, critic: ActionTokenCritic, suite_name: str, task_ids: List[int], n_initial_states: int, action_norm_stats: dict, max_steps: int, chunk_len: int, G_per_task: int = 8, seed: int = 42, num_steps_wait: int = 10, device: str = 'cuda', group_idx: int = 0, store_images: bool = False, group_size: int = 1, reward_coef: float = 1.0, actor_chunk_len: int = None, warmup_mode: bool = False) -> List[ActionTokenEpisode]

Collect episodes for MULTIPLE tasks on ONE GPU in a single step-lock loop.

All tasks' envs are merged into one batch for VLA forward — no per-task threading, no CUDA concurrency issues, maximum GPU batch utilization.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `task_ids` | `List[int]` | list of task IDs to run on this GPU | required |
| `G_per_task` | `int` | episodes per task | `8` |

Returns: flat list of all episodes across all tasks
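The per-task state assignment (same-state grouping, then flat concatenation across tasks) can be sketched with NumPy; the function name and arguments below are illustrative, not the repo's API:

```python
import numpy as np

def assign_states(task_ids, G_per_task, group_size, n_initial_states, seed):
    """Each task draws G_per_task // group_size unique initial states,
    repeats each group_size times, and all tasks are concatenated into
    one flat list of (state_id, task_label) episode slots."""
    rng = np.random.RandomState(seed)
    state_ids, task_labels = [], []
    for tid in task_ids:
        num_unique = max(1, G_per_task // group_size)
        states = rng.randint(0, n_initial_states, size=num_unique)
        states = np.repeat(states, group_size)[:G_per_task]
        state_ids.extend(states.tolist())
        task_labels.extend([tid] * G_per_task)
    return state_ids, task_labels

states, labels = assign_states([3, 7], G_per_task=4, group_size=2,
                               n_initial_states=50, seed=0)
```

With `group_size=2`, consecutive episode pairs share an initial state, which is what makes same-state advantage grouping possible downstream.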

Source code in AlphaBrain/training/reinforcement_learning/algos/RLActionToken/action_token_rollout_fast.py
@torch.no_grad()
def action_token_collect_multitask_steplock(
    env_pool: PersistentEnvPool,
    frozen_vla,
    encoder: ActionTokenEncoderDecoder,
    actor: ActionTokenActor,
    critic: ActionTokenCritic,
    suite_name: str,
    task_ids: List[int],
    n_initial_states: int,
    action_norm_stats: dict,
    max_steps: int,
    chunk_len: int,
    G_per_task: int = 8,
    seed: int = 42,
    num_steps_wait: int = 10,
    device: str = "cuda",
    group_idx: int = 0,
    store_images: bool = False,
    group_size: int = 1,
    reward_coef: float = 1.0,
    actor_chunk_len: int = None,
    warmup_mode: bool = False,
) -> List[ActionTokenEpisode]:
    """
    Collect episodes for MULTIPLE tasks on ONE GPU in a single step-lock loop.

    All tasks' envs are merged into one batch for VLA forward — no per-task
    threading, no CUDA concurrency issues, maximum GPU batch utilization.

    Args:
        task_ids: list of task IDs to run on this GPU
        G_per_task: episodes per task
    Returns:
        flat list of all episodes across all tasks
    """
    if actor_chunk_len is None:
        actor_chunk_len = chunk_len
    exec_chunk_len = actor_chunk_len

    frozen_vla.eval()
    encoder.eval()
    actor.eval()

    n_tasks = len(task_ids)
    total_G = G_per_task * n_tasks

    # Assign states per task
    _rng = np.random.RandomState(seed + group_idx)
    all_state_ids = []
    all_task_labels = []  # which task each episode belongs to
    for tid in task_ids:
        num_unique = max(1, G_per_task // group_size)
        states = _rng.randint(0, n_initial_states, size=num_unique)
        states = np.repeat(states, group_size)[:G_per_task]
        all_state_ids.extend(states)
        all_task_labels.extend([tid] * G_per_task)

    n_workers = min(total_G, len(env_pool))

    # ── Phase 1: Reset all envs in parallel ──
    from concurrent.futures import as_completed as _as_completed
    obs_list = [None] * total_G
    with ThreadPoolExecutor(max_workers=total_G) as _pool:
        _futs = {_pool.submit(env_pool.reset_env, g, suite_name, all_task_labels[g], int(all_state_ids[g]), seed + g): g for g in range(total_G)}
        for _f in _as_completed(_futs):
            obs_list[_futs[_f]] = _f.result()
    print(f"  reset done: {total_G} envs (parallel)", flush=True)

    task_descriptions = [env_pool.envs[g].task_description for g in range(total_G)]

    # ── Phase 2: Warmup (parallel) ──
    if num_steps_wait > 0:
        with ThreadPoolExecutor(max_workers=total_G) as _pool:
            _futs = {_pool.submit(_env_dummy_steps, env_pool, g, num_steps_wait): g for g in range(total_G)}
            for _f in _as_completed(_futs):
                obs_list[_futs[_f]] = _f.result()

    # ── Phase 3: Step-lock main loop (ALL tasks merged) ──
    episodes = [ActionTokenEpisode(task_id=all_task_labels[g], state_idx=int(all_state_ids[g]))
                for g in range(total_G)]
    active = [True] * total_G
    env_steps = [0] * total_G
    max_chunks = max_steps // exec_chunk_len + 1

    _t_vla = 0.0
    _t_env = 0.0
    _n_chunks = 0

    for chunk_idx in range(max_chunks):
        active_ids = [g for g in range(total_G) if active[g]]
        if not active_ids:
            break

        # ── ONE batched VLA forward for ALL active envs across ALL tasks ──
        _t0 = time.time()
        batch_images = [[obs_list[g]["primary_image"], obs_list[g]["wrist_image"]] for g in active_ids]
        batch_instrs = [task_descriptions[g] for g in active_ids]
        batch_props = [np.array(obs_list[g]["state"], dtype=np.float32) for g in active_ids]

        with torch.autocast("cuda", dtype=torch.bfloat16):
            action_queries, vla_actions = frozen_vla.get_vla_action(
                batch_images=batch_images, instructions=batch_instrs)

        rl_tokens = encoder.encode(action_queries)
        props_t = torch.tensor(np.array(batch_props), dtype=torch.float32).to(device)

        if actor_chunk_len < vla_actions.size(1):
            vla_actions_for_actor = vla_actions[:, :actor_chunk_len, :]
        else:
            vla_actions_for_actor = vla_actions

        if warmup_mode:
            actions_t = vla_actions_for_actor
            log_probs = torch.zeros(len(active_ids), device=device)
            values = torch.zeros(len(active_ids), device=device)
        else:
            actions_t, log_probs = actor(rl_tokens, vla_actions_for_actor, props_t, deterministic=False)
            values = critic(rl_tokens)
        _t_vla += time.time() - _t0

        actions_np = actions_t.cpu().numpy()
        vla_actions_cpu = vla_actions_for_actor.cpu()

        # Store records
        for i, g in enumerate(active_ids):
            episodes[g].step_records.append(ActionTokenStepRecord(
                rl_token=rl_tokens[i:i+1].cpu().squeeze(0),
                vla_action=vla_actions_cpu[i],
                action_taken=actions_t[i].detach().cpu(),
                old_log_prob=log_probs[i].item() if log_probs is not None else 0.0,
                value=values[i].item(),
                prop_state=torch.tensor(batch_props[i]),
            ))

        # Unnormalize
        action_chunks_unnorm = [_unnormalize(actions_np[i], action_norm_stats) for i in range(len(active_ids))]

        # ── ALL envs execute chunk in parallel ──
        _t0 = time.time()
        with ThreadPoolExecutor(max_workers=len(active_ids)) as _pool:
            _futs = {}
            for i, g in enumerate(active_ids):
                _futs[_pool.submit(_env_step_chunk, env_pool, g, action_chunks_unnorm[i],
                                   exec_chunk_len, False)] = (i, g)
            for _f in _as_completed(_futs):
                i, g = _futs[_f]
                obs, reward, done, steps_taken, _ = _f.result()
                obs_list[g] = obs
                env_steps[g] += steps_taken
                if done or env_steps[g] >= max_steps:
                    active[g] = False
                    ep = episodes[g]
                    ep.success = bool(done and reward > 0.5)
                    ep.reward = reward_coef if ep.success else 0.0
                    ep.done_cache_idx = steps_taken
                    ep.finish_step = len(ep.step_records)
                    ep.env_steps = env_steps[g]
        _t_env += time.time() - _t0
        _n_chunks += 1

    # Finalize
    for g in range(total_G):
        ep = episodes[g]
        if ep.finish_step == 0:
            ep.finish_step = len(ep.step_records)
            ep.env_steps = env_steps[g]
            ep.reward = 0.0

    if _n_chunks > 0:
        print(f"[MULTITASK TIMING] {n_tasks} tasks × {G_per_task} eps = {total_G} total | "
                     f"{_n_chunks} chunks | vla={_t_vla:.1f}s env={_t_env:.1f}s "
                     f"total={_t_vla+_t_env:.1f}s")

    return episodes

Shared components (common/)

Rollout

rollout

Episode rollout for GRPO.

Key design
  • Each episode stores trajectory as tensors: (traj_len, ...) for batched forward
  • finish_step tracks actual episode length for masking
  • Multiprocess env workers (one process per env)
  • Gaussian policy: a ~ N(μ, σ²I), log_prob = -||a-μ||²/(2σ²)
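The Gaussian log-probability in the last bullet is shorthand: it drops the normalization constant, which for a fixed σ only shifts log-probs uniformly (and cancels in PPO-style ratios). A full diagonal-Gaussian version as a standalone sketch:

```python
import math
import torch

def gaussian_log_prob(a, mu, sigma):
    """Log-density of a ~ N(mu, sigma^2 I), summed over action dims.
    The doc's shorthand -||a - mu||^2 / (2 sigma^2) omits the constant
    -D/2 * log(2 pi sigma^2) included here."""
    d = a.numel()
    sq = ((a - mu) ** 2).sum()
    return (-sq / (2 * sigma ** 2)
            - 0.5 * d * math.log(2 * math.pi * sigma ** 2))

a = torch.tensor([0.1, -0.2])
mu = torch.zeros(2)
lp = gaussian_log_prob(a, mu, sigma=1.0)
```

This matches `torch.distributions.Normal(mu, sigma).log_prob(a).sum()` for the diagonal case.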
StepRecord dataclass
StepRecord(primary_image: ndarray, wrist_image: ndarray, instruction: str, norm_action: ndarray, old_log_prob: float, value: float = 0.0, action_token_ids: ndarray = None)

One inference step = one action-chunk prediction.

Replay buffer

replay_buffer

Off-policy replay buffer for RLActionToken training (RL Token paper style).

Stores transitions as detached CPU tensors in a ring buffer. Each transition stores: - rl_token (state), action_taken, reward, next_rl_token (next state), done - vla_ref_action: the VLA reference action chunk (for BC regularization) - next_vla_action: VLA reference at next state (for target policy sampling) - prop_state / next_prop_state: proprioceptive state (eef_pos+axisangle+gripper, 8-dim) - task_id: integer task index for per-task stratified sampling

ReplayBuffer
ReplayBuffer(capacity: int = 100000)

Fixed-capacity ring buffer for off-policy experience replay.

Source code in AlphaBrain/training/reinforcement_learning/common/replay_buffer.py
def __init__(self, capacity: int = 100_000):
    self.capacity = capacity
    self.buffer = []
    self.pos = 0
    # task_id index: task_id -> list of buffer positions (for stratified sampling)
    self._task_index: dict = defaultdict(list)
push
push(rl_token: Tensor, vla_action: Tensor, action_taken: Tensor, reward: float, next_rl_token: Tensor, next_vla_action: Tensor, done: bool, task_id: int = 0, prop_state: Optional[Tensor] = None, next_prop_state: Optional[Tensor] = None)

Store a single transition (all tensors detached to CPU).

Source code in AlphaBrain/training/reinforcement_learning/common/replay_buffer.py
def push(
    self,
    rl_token: torch.Tensor,        # (1, D) or (D,)
    vla_action: torch.Tensor,       # (chunk_len, action_dim) — VLA reference
    action_taken: torch.Tensor,     # (chunk_len, action_dim) — actual executed
    reward: float,
    next_rl_token: torch.Tensor,    # (1, D) or (D,)
    next_vla_action: torch.Tensor,  # (chunk_len, action_dim)
    done: bool,
    task_id: int = 0,
    prop_state: Optional[torch.Tensor] = None,       # (prop_dim,)
    next_prop_state: Optional[torch.Tensor] = None,  # (prop_dim,)
):
    """Store a single transition (all tensors detached to CPU)."""
    # Default zero prop states if not provided
    if prop_state is None:
        prop_state = torch.zeros(8, dtype=torch.float32)
    if next_prop_state is None:
        next_prop_state = torch.zeros(8, dtype=torch.float32)

    transition = (
        rl_token.detach().cpu(),
        vla_action.detach().cpu(),
        action_taken.detach().cpu(),
        torch.tensor(reward, dtype=torch.float32),
        next_rl_token.detach().cpu(),
        next_vla_action.detach().cpu(),
        torch.tensor(float(done), dtype=torch.float32),
        prop_state.detach().cpu(),
        next_prop_state.detach().cpu(),
        task_id,  # stored as plain int at index 9, stripped in _collect
    )
    if len(self.buffer) < self.capacity:
        idx = len(self.buffer)
        self.buffer.append(transition)
    else:
        idx = self.pos
        # Remove old task_id index entry for overwritten slot
        old_task_id = self.buffer[idx][9]
        old_list = self._task_index[old_task_id]
        # Efficiently remove the old index (swap with last element)
        try:
            pos_in_list = old_list.index(idx)
            old_list[pos_in_list] = old_list[-1]
            old_list.pop()
        except ValueError:
            pass
        self.buffer[idx] = transition

    self._task_index[task_id].append(idx)
    self.pos = (self.pos + 1) % self.capacity
sample
sample(batch_size: int, device: str = 'cuda') -> Tuple[torch.Tensor, ...]

Sample a random mini-batch (uniform over all transitions).

Returns:

| Type | Description |
| --- | --- |
| `Tuple[Tensor, ...]` | `(rl_tokens, vla_actions, actions_taken, rewards, next_rl_tokens, next_vla_actions, dones, prop_states, next_prop_states)`; each tensor has the batch dim prepended and is moved to `device`. |

Source code in AlphaBrain/training/reinforcement_learning/common/replay_buffer.py
def sample(
    self,
    batch_size: int,
    device: str = "cuda",
) -> Tuple[torch.Tensor, ...]:
    """
    Sample a random mini-batch (uniform over all transitions).

    Returns:
        Tuple of (rl_tokens, vla_actions, actions_taken, rewards,
                   next_rl_tokens, next_vla_actions, dones,
                   prop_states, next_prop_states)
        Each tensor has batch dim prepended and is moved to `device`.
    """
    indices = np.random.choice(len(self.buffer), batch_size, replace=False)
    return self._collect(indices, device)
sample_balanced
sample_balanced(batch_size: int, n_tasks: int, device: str = 'cuda') -> Tuple[torch.Tensor, ...]

Per-task stratified sampling: sample equal number of transitions from each task.

Ensures each task contributes equally to each gradient update, matching GRPO's equal-frequency multi-task update property.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch_size` | `int` | total transitions to sample | required |
| `n_tasks` | `int` | number of tasks (0..n_tasks-1) | required |
| `device` | `str` | target device | `'cuda'` |
Source code in AlphaBrain/training/reinforcement_learning/common/replay_buffer.py
def sample_balanced(
    self,
    batch_size: int,
    n_tasks: int,
    device: str = "cuda",
) -> Tuple[torch.Tensor, ...]:
    """
    Per-task stratified sampling: sample equal number of transitions from each task.

    Ensures each task contributes equally to each gradient update, matching
    GRPO's equal-frequency multi-task update property.

    Args:
        batch_size: total transitions to sample
        n_tasks: number of tasks (0..n_tasks-1)
        device: target device
    """
    per_task = max(1, batch_size // n_tasks)
    all_indices = []

    for tid in range(n_tasks):
        pool = self._task_index.get(tid, [])
        if not pool:
            continue
        n_sample = min(per_task, len(pool))
        chosen = np.random.choice(pool, n_sample, replace=(n_sample > len(pool)))
        all_indices.extend(chosen.tolist())

    if not all_indices:
        # Fallback to uniform if index is empty
        return self.sample(batch_size, device)

    # Trim or pad to batch_size
    if len(all_indices) > batch_size:
        all_indices = all_indices[:batch_size]

    return self._collect(all_indices, device)
is_ready
is_ready(min_size: int = 256) -> bool

Whether buffer has enough samples for at least one batch.

Source code in AlphaBrain/training/reinforcement_learning/common/replay_buffer.py
def is_ready(self, min_size: int = 256) -> bool:
    """Whether buffer has enough samples for at least one batch."""
    return len(self.buffer) >= min_size
task_counts
task_counts() -> dict

Return number of transitions per task (for diagnostics).

Source code in AlphaBrain/training/reinforcement_learning/common/replay_buffer.py
def task_counts(self) -> dict:
    """Return number of transitions per task (for diagnostics)."""
    return {tid: len(indices) for tid, indices in self._task_index.items()}
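When the ring buffer wraps, `push` must also evict the overwritten slot's entry from the task index; the swap-with-last removal can be isolated as a minimal standalone sketch (a toy mirror of the pattern, not the repo's class):

```python
from collections import defaultdict

class MiniRing:
    """Minimal ring buffer with a task_id -> positions index, mirroring
    the swap-with-last eviction used when a slot is overwritten."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.pos = 0
        self.task_index = defaultdict(list)

    def push(self, item, task_id):
        if len(self.buffer) < self.capacity:
            idx = len(self.buffer)
            self.buffer.append((item, task_id))
        else:
            idx = self.pos
            old_task = self.buffer[idx][1]
            lst = self.task_index[old_task]
            # O(1) removal: swap the evicted position with the last entry.
            i = lst.index(idx)
            lst[i] = lst[-1]
            lst.pop()
            self.buffer[idx] = (item, task_id)
        self.task_index[task_id].append(idx)
        self.pos = (self.pos + 1) % self.capacity

buf = MiniRing(capacity=2)
buf.push("a", task_id=0)
buf.push("b", task_id=0)
buf.push("c", task_id=1)  # wraps: overwrites slot 0, evicts it from task 0
```

The swap keeps eviction amortized O(n) in the per-task list length only for the `index` lookup; the removal itself never shifts the list.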

Checkpoint I/O

ckpt_io

Checkpoint save helper shared by all RLActionToken training phases.


Environments (envs/)

LIBERO environment

libero_env

LIBERO environment proxy — runs in the VLA Python environment.

Spawns libero_env_worker.py as a subprocess using LIBERO_PYTHON (the separate conda env that has libero installed), then communicates via stdin/stdout with length-prefixed msgpack messages.

Usage matches the original direct API:

    env = LiberoEnv(suite_name, task_id, seed)
    obs = env.reset(initial_state_idx=0)
    obs, reward, done = env.step(action_7d)
    env.close()

LiberoEnv
LiberoEnv(libero_python: Optional[str] = None)

Proxy to a LIBERO environment running in a separate Python process.

The worker process is started once per LiberoEnv instance and reused across reset() calls (different tasks can be loaded with reset).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `libero_python` | `Optional[str]` | Path to the LIBERO conda env Python binary. Defaults to `LIBERO_PYTHON` env var, then `python`. | `None` |
Source code in AlphaBrain/training/reinforcement_learning/envs/libero_env.py
def __init__(
    self,
    libero_python: Optional[str] = None,
):
    """
    Args:
        libero_python: Path to the LIBERO conda env Python binary.
                       Defaults to LIBERO_PYTHON env var, then 'python'.
    """
    python_bin = (
        libero_python
        or os.environ.get("LIBERO_PYTHON", "python")
    )

    # Inherit the current env so LIBERO can find its own packages.
    # Inject LIBERO_HOME into PYTHONPATH so the editable install is not required.
    worker_env = os.environ.copy()
    libero_home = os.environ.get("LIBERO_HOME", "")
    if libero_home:
        existing = worker_env.get("PYTHONPATH", "")
        worker_env["PYTHONPATH"] = f"{libero_home}:{existing}" if existing else libero_home

    self._proc = subprocess.Popen(
        [python_bin, _WORKER_SCRIPT],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=worker_env,
    )

    self.task_description: str = ""
    self.max_steps: int = 300
    self._closed = False
reset
reset(suite_name: str, task_id: int, initial_state_idx: int = 0, seed: int = 42) -> dict

Reset the environment to a specific task and initial state.

Returns obs dict
  • "primary_image" : PIL.Image
  • "wrist_image" : PIL.Image
  • "state" : np.ndarray (8,)
Source code in AlphaBrain/training/reinforcement_learning/envs/libero_env.py
def reset(
    self,
    suite_name: str,
    task_id: int,
    initial_state_idx: int = 0,
    seed: int = 42,
) -> dict:
    """
    Reset the environment to a specific task and initial state.

    Returns obs dict:
      - "primary_image"  : PIL.Image
      - "wrist_image"    : PIL.Image
      - "state"          : np.ndarray (8,)
    """
    _write_msg(self._proc, {
        "cmd":               "reset",
        "task_suite":        suite_name,
        "task_id":           task_id,
        "initial_state_idx": initial_state_idx,
        "seed":              seed,
    })
    resp = _read_msg(self._proc)
    _check_resp(resp)

    self.task_description = resp["task_description"]
    self.max_steps = resp["max_steps"]
    return _parse_obs(resp["obs"])
step
step(action_7d: ndarray) -> Tuple[dict, float, bool]

Execute one env step.

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `obs_dict` | `dict` | parsed observation |
| `reward` | `float` | 0.0 / 1.0 |
| `done` | `bool` | episode termination flag |

Source code in AlphaBrain/training/reinforcement_learning/envs/libero_env.py
def step(self, action_7d: np.ndarray) -> Tuple[dict, float, bool]:
    """
    Execute one env step.

    Returns:
        obs_dict  : parsed observation
        reward    : 0.0 / 1.0
        done      : episode termination flag
    """
    _write_msg(self._proc, {"cmd": "step", "action": action_7d.tolist()})
    resp = _read_msg(self._proc)
    _check_resp(resp)
    return _parse_obs(resp["obs"]), float(resp["reward"]), bool(resp["done"])
get_suite_info
get_suite_info(suite_name: str, libero_python: Optional[str] = None) -> dict

Query task count and task names from the LIBERO worker without opening an environment.

Returns: {"n_tasks": int, "task_names": [str, ...]}

Source code in AlphaBrain/training/reinforcement_learning/envs/libero_env.py
def get_suite_info(suite_name: str, libero_python: Optional[str] = None) -> dict:
    """
    Query task count and task names from the LIBERO worker without
    opening an environment.

    Returns: {"n_tasks": int, "task_names": [str, ...]}
    """
    python_bin = libero_python or os.environ.get("LIBERO_PYTHON", "python")
    script = (
        "import sys; _real_stdout = sys.stdout; sys.stdout = sys.stderr; "
        "from libero.libero import benchmark; "
        f"s = benchmark.get_benchmark_dict()['{suite_name}'](); "
        "sys.stdout = _real_stdout; "
        "import json; "
        "json.dump({'n_tasks': s.n_tasks, "
        "'task_names': [s.get_task(i).language for i in range(s.n_tasks)]}, sys.stdout)"
    )
    run_env = os.environ.copy()
    libero_home = os.environ.get("LIBERO_HOME", "")
    if libero_home:
        existing = run_env.get("PYTHONPATH", "")
        run_env["PYTHONPATH"] = f"{libero_home}:{existing}" if existing else libero_home
    result = subprocess.run(
        [python_bin, "-c", script],
        capture_output=True, text=True, timeout=30,
        env=run_env,
    )
    if result.returncode != 0:
        raise RuntimeError(f"get_suite_info failed:\n{result.stderr}")
    import json
    return json.loads(result.stdout)

LIBERO environment workers

libero_env_worker

LIBERO environment worker — runs inside the LIBERO Python environment.

Launched as a subprocess by libero_env.py (VLA Python env). Communicates via stdin/stdout using length-prefixed msgpack messages.

Protocol (both directions): [4-byte little-endian uint32 length][msgpack payload]

Commands received (from parent):

    {"cmd": "reset", "task_suite": str, "task_id": int, "initial_state_idx": int, "seed": int}
    {"cmd": "step", "action": [7 floats]}

Responses sent (to parent):

    {"status": "ok", "obs": {"primary": ..., "wrist": ..., "state": [8 floats]}, "reward": float, "done": bool}
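The length-prefix framing described above is independent of the msgpack payload; a minimal sketch of just the framing layer (illustrative helper names, not the module's actual `_write_msg`/`_read_msg`, and serialization is left to the caller):

```python
import io
import struct

def write_frame(stream, payload: bytes) -> None:
    # Frame layout: [4-byte little-endian uint32 length][payload bytes]
    stream.write(struct.pack("<I", len(payload)) + payload)
    stream.flush()

def read_frame(stream) -> bytes:
    # Read the fixed-size length header first, then exactly that many bytes
    header = stream.read(4)
    (length,) = struct.unpack("<I", header)
    return stream.read(length)
```

In the real worker the payload bytes would come from `msgpack.packb(...)` and be decoded with `msgpack.unpackb(...)` on the other side.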

libero_env_worker_fast

Fast LIBERO environment worker — socket pair IPC, MuJoCo env reuse.

Launched by persistent_env_pool.py. Receives socket FD as argv[1]. Communicates via socket (not stdin/stdout) — eliminates pipe buffer deadlock.

Key optimization: same task → only reset() + set_init_state() (no env recreation).
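The same-task reuse rule amounts to keying a live env on `(suite, task_id)`; a hypothetical sketch of that cache (the `TaskEnvCache` class and its factory argument are illustrative, not the worker's actual code):

```python
class TaskEnvCache:
    """Reuse a live MuJoCo env while the requested task is unchanged."""

    def __init__(self, make_env):
        self._make_env = make_env  # factory: (suite, task_id) -> env object
        self._key = None
        self._env = None

    def get(self, suite: str, task_id: int):
        key = (suite, task_id)
        if key != self._key:
            # Different task: tear down and recreate (the slow, rare path)
            if self._env is not None:
                self._env.close()
            self._env = self._make_env(suite, task_id)
            self._key = key
        # Same task: caller only needs env.reset() + env.set_init_state(...)
        return self._env
```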

Persistent environment pool

persistent_env_pool

Persistent LiberoEnv pool — keeps subprocess envs alive across rollout iterations.

IPC: socket pair (bidirectional, with settimeout) instead of stdin/stdout pipes. This eliminates the pipe buffer deadlock that occurs with high-concurrency pipe I/O.

Two-layer optimization
  1. Subprocess pool: LiberoEnv subprocesses created once, reused across iterations
  2. Fast worker: libero_env_worker_fast.py reuses MuJoCo env for same task
PersistentEnvPool
PersistentEnvPool(num_envs: int, libero_python: Optional[str] = None, egl_gpu_id: Optional[int] = None)

Pool of persistent fast LiberoEnv subprocess workers.

Each worker subprocess stays alive for the entire training run. On reset with same task: only reset() + set_init_state() in worker (fast). On reset with different task: worker recreates MuJoCo env (slower, but rare).

Source code in AlphaBrain/training/reinforcement_learning/envs/persistent_env_pool.py
def __init__(
    self,
    num_envs: int,
    libero_python: Optional[str] = None,
    egl_gpu_id: Optional[int] = None,
):
    self.num_envs = num_envs
    self.libero_python = libero_python
    self.envs: List[_FastLiberoEnv] = []

    gpu_label = f", EGL GPU={egl_gpu_id}" if egl_gpu_id is not None else ""
    print(f"Creating {num_envs} persistent fast LiberoEnv workers{gpu_label}...", flush=True)
    for i in range(num_envs):
        env = _FastLiberoEnv(libero_python=libero_python, egl_gpu_id=egl_gpu_id)
        self.envs.append(env)
        if (i + 1) % 10 == 0 or i == num_envs - 1:
            print(f"  env workers: {i+1}/{num_envs}", flush=True)
    print(f"PersistentEnvPool ready: {num_envs} workers", flush=True)
reset_env
reset_env(env_idx: int, suite_name: str, task_id: int, state_idx: int, seed: int) -> dict

Reset a single env to given task/state. Returns obs dict.

Source code in AlphaBrain/training/reinforcement_learning/envs/persistent_env_pool.py
def reset_env(
    self,
    env_idx: int,
    suite_name: str,
    task_id: int,
    state_idx: int,
    seed: int,
) -> dict:
    """Reset a single env to given task/state. Returns obs dict."""
    return self.envs[env_idx].reset(
        suite_name=suite_name,
        task_id=task_id,
        initial_state_idx=state_idx,
        seed=seed,
    )
step_env
step_env(env_idx: int, action: ndarray)

Step a single env. Returns (obs, reward, done).

Source code in AlphaBrain/training/reinforcement_learning/envs/persistent_env_pool.py
def step_env(self, env_idx: int, action: np.ndarray):
    """Step a single env. Returns (obs, reward, done)."""
    return self.envs[env_idx].step(action)
close
close()

Close all subprocess workers.

Source code in AlphaBrain/training/reinforcement_learning/envs/persistent_env_pool.py
def close(self):
    """Close all subprocess workers."""
    for env in self.envs:
        try:
            env.close()
        except Exception:
            pass
    self.envs.clear()
    print("PersistentEnvPool closed", flush=True)

Evaluation (eval/)

eval_libero

Standalone offline eval for an RLActionToken iter checkpoint on LIBERO.

Loads
  • frozen QwenOFT VLA from --vla_ckpt
  • ActionTokenEncoderDecoder from /encoder.pt
  • ActionTokenActor from /actor.pt

Runs deterministic eval across all (or selected) tasks of the suite, prints per-task SR, and optionally appends the result to a JSON file.

The eval protocol is shared with training via AlphaBrain.training.reinforcement_learning.eval.eval_helpers._eval_deterministic_local.

eval_helpers

Deterministic eval helpers shared by training loops and standalone eval.

aggregate_shards

Aggregate per-shard eval JSONs into a single summary.

Each shard JSON (produced by eval_libero.py --results_json shard_X.json) contains a per_task_sr dict for the subset of tasks that shard ran. This script merges them into one summary with all per-task SRs and the overall SR (mean across tasks).

Usage

python AlphaBrain/training/reinforcement_learning/eval/aggregate_shards.py \
    --out_dir ... \
    --action_token_ckpt ... \
    --vla_ckpt ... \
    --suite libero_goal \
    --n_eps 50


Training entrypoints (trainers/)

Shared CLI arguments

train_args

CLI args for RLActionToken training (all three phases).

Main entrypoint

train

ActionToken training entry for QwenOFT on LIBERO.

Three phases

--phase pretrain        Encoder-decoder pretraining via reconstruction loss
--phase rl              On-policy multi-GPU PPO update (legacy; PPO/GRPO is a TODO)
--phase rl_offpolicy    Off-policy TD3 with split rollout/training GPUs (production)

Usage
Phase 1: Encoder pretraining

python AlphaBrain/training/reinforcement_learning/trainers/train.py \
    --phase pretrain \
    --ckpt_path results/training/my_sft/final_model \
    --suite libero_goal --task_id 0

Phase 2 (production): off-policy TD3

python AlphaBrain/training/reinforcement_learning/trainers/train.py \
    --phase rl_offpolicy \
    --ckpt_path results/training/my_sft/final_model \
    --encoder_path results/action_token_training/pretrain/checkpoints/pretrain_best/encoder.pt \
    --suite libero_goal --task_id 0

Pretrain

train_pretrain

Phase 1: encoder-decoder pretraining via reconstruction loss.

On-policy RL

train_rl_onpolicy

Phase 2 (legacy on-policy variant): multi-GPU rollout + PPO update.

NOTE: this is the legacy on-policy training path. The off-policy TD3 variant (train_rl_offpolicy.run_rl_offpolicy) is the production code path used by every release run; this file is kept for reference.

TODO: implement proper PPO / GRPO updates here. The current action_token_ppo_loss is a placeholder that mixes value-loss + clipped surrogate; before relying on this phase for new experiments, port a battle-tested PPO loop (importance sampling, KL early stop, value clipping, gradient accumulation, group-relative normalization for GRPO, etc.) from a reference implementation.
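For reference, the clipped surrogate such a port would build on looks like the standard PPO-Clip objective below (a textbook sketch, not the repo's `action_token_ppo_loss`):

```python
import torch

def clipped_surrogate_loss(logp_new, logp_old, advantages, clip_eps=0.2):
    """PPO-Clip surrogate (negated for minimization); reference sketch only."""
    ratio = torch.exp(logp_new - logp_old)  # importance weights pi_new / pi_old
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Pessimistic (elementwise min) bound, averaged over the batch
    return -torch.min(unclipped, clipped).mean()
```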

run_rl
run_rl(args)

Phase 2 on-policy: multi-GPU parallel rollout + PPO update.

Each GPU loads a frozen VLA copy and runs its own env workers to collect episodes in parallel. All episodes are gathered, then the tiny RLActionToken network update happens on every rank (identical, since network is tiny).

6 GPUs = 6× rollout throughput (the actual bottleneck is CPU env.step).

Source code in AlphaBrain/training/reinforcement_learning/trainers/train_rl_onpolicy.py
def run_rl(args):
    """Phase 2 on-policy: multi-GPU parallel rollout + PPO update.

    Each GPU loads a frozen VLA copy and runs its own env workers to collect
    episodes in parallel. All episodes are gathered, then the tiny RLActionToken
    network update happens on every rank (identical, since network is tiny).

    6 GPUs = 6× rollout throughput (the actual bottleneck is CPU env.step).
    """
    set_seed(args.seed)
    accelerator = Accelerator()
    device = accelerator.device
    rank = accelerator.process_index
    world_size = accelerator.num_processes
    is_main = accelerator.is_main_process

    # Each rank loads its own frozen VLA copy (~12GB per GPU)
    logger.info(f"[rank {rank}/{world_size}] Loading frozen VLA from {args.ckpt_path}")
    frozen_vla = BaseFramework.from_pretrained(args.ckpt_path)
    frozen_vla = frozen_vla.to(torch.bfloat16).to(device).eval()
    for param in frozen_vla.parameters():
        param.requires_grad_(False)

    hidden_dim = frozen_vla.qwen_vl_interface.model.config.hidden_size
    chunk_len = frozen_vla.chunk_len
    action_dim = frozen_vla.config.framework.action_model.action_dim

    _norm_stats = frozen_vla.norm_stats
    _unnorm_key = next(iter(_norm_stats.keys()))
    action_norm_stats = _norm_stats[_unnorm_key]["action"]

    suite_info = get_suite_info(args.suite)
    n_tasks = suite_info["n_tasks"]
    max_steps = MAX_STEPS[args.suite]

    # Create RLActionToken modules (tiny, same on all ranks)
    enc_dec = ActionTokenEncoderDecoder(
        input_dim=hidden_dim,
        bottleneck_dim=args.bottleneck_dim,
        chunk_len=chunk_len,
        num_heads=args.encoder_heads,
        encoder_layers=args.encoder_layers,
        decoder_layers=args.encoder_layers,
    ).to(device)

    if args.encoder_path:
        logger.info(f"[rank {rank}] Loading pretrained encoder from {args.encoder_path}")
        state = torch.load(args.encoder_path, map_location=device)
        enc_dec.load_state_dict(state)

    actor = ActionTokenActor(
        bottleneck_dim=args.bottleneck_dim,
        action_dim=action_dim,
        chunk_len=chunk_len,
        hidden_dim=args.actor_hidden_dim,
        ref_dropout=args.ref_dropout,
    ).to(device)

    critic = ActionTokenCritic(
        bottleneck_dim=args.bottleneck_dim,
        hidden_dim=args.critic_hidden_dim,
    ).to(device)

    if is_main:
        enc_params = sum(p.numel() for p in enc_dec.parameters())
        actor_params = sum(p.numel() for p in actor.parameters())
        critic_params = sum(p.numel() for p in critic.parameters())
        vla_params = sum(p.numel() for p in frozen_vla.parameters())
        logger.info(f"Frozen VLA: {vla_params / 1e9:.2f}B params × {world_size} GPUs")
        logger.info(f"RLActionToken trainable: encoder={enc_params / 1e6:.2f}M, "
                    f"actor={actor_params / 1e6:.2f}M, critic={critic_params / 1e6:.2f}M")
        logger.info(f"Rollout parallelism: {world_size} ranks × {args.num_envs} envs × "
                    f"{args.G} episodes/rank = {world_size * args.G} episodes/iter")

    # Optimizer
    param_groups = [
        {"params": actor.parameters(), "lr": args.lr_actor},
        {"params": critic.parameters(), "lr": args.lr_critic},
    ]
    if args.lr_encoder > 0:
        param_groups.append({"params": enc_dec.parameters(), "lr": args.lr_encoder})
    else:
        for p in enc_dec.parameters():
            p.requires_grad_(False)

    optimizer = torch.optim.AdamW(param_groups, betas=(0.9, 0.95), weight_decay=1e-8)

    # WandB (main rank only)
    if args.use_wandb and is_main:
        run_name = args.run_name or f"action_token_rl_{args.suite}_task{args.task_id}"
        wandb.init(project=args.wandb_project, name=run_name,
                   config={**vars(args), "chunk_len": chunk_len,
                           "hidden_dim": hidden_dim, "action_dim": action_dim,
                           "world_size": world_size})

    video_dir = Path(args.output_dir) / "videos"
    metrics_history = []
    best_sr = 0.0
    best_eval_sr = 0.0
    running_sr = []
    total_env_steps = 0  # cumulative environment steps (sample steps)

    # ── Training loop ──────────────────────────────────────
    for iteration in range(1, args.max_iter + 1):
        if is_main:
            logger.info(f"{'='*60}")
            logger.info(f"[iter {iteration}/{args.max_iter}] Collecting "
                         f"{args.G}×{world_size}={args.G * world_size} episodes across {world_size} GPUs...")

        save_video = (args.save_video_interval > 0 and
                      (iteration == 1 or iteration % args.save_video_interval == 0))
        iter_video_dir = (str(video_dir / f"iter_{iteration:05d}")
                          if save_video and is_main else None)

        task_id = args.task_id if args.task_id >= 0 else random.randint(0, n_tasks - 1)

        # ── Each rank collects G episodes in parallel ────────
        group_seed = args.seed + iteration * 1000 + rank * 100
        local_episodes = action_token_collect_group(
            frozen_vla=frozen_vla,
            encoder=enc_dec,
            actor=actor,
            critic=critic,
            suite_name=args.suite,
            task_id=task_id,
            n_initial_states=50,
            action_norm_stats=action_norm_stats,
            max_steps=max_steps,
            chunk_len=chunk_len,
            G=args.G,
            libero_python=os.environ.get("LIBERO_PYTHON"),
            seed=group_seed,
            num_steps_wait=args.num_steps_wait,
            device=str(device),
            video_dir=iter_video_dir,
            num_envs=args.num_envs,
            group_idx=iteration * world_size + rank,
            group_size=args.group_size,
            reward_coef=args.reward_coef,
        )

        # Gather rewards from all ranks for global stats
        local_rewards = torch.tensor(
            [ep.reward for ep in local_episodes], device=device, dtype=torch.float32)
        global_rewards = accelerator.gather(local_rewards).cpu().numpy()

        success_rate = float(np.mean(global_rewards > 0.5))
        mean_reward = float(np.mean(global_rewards))
        mean_steps = np.mean([ep.finish_step for ep in local_episodes])
        # Accumulate env steps: gather local env_steps across all ranks
        local_env_steps = torch.tensor(
            sum(ep.env_steps for ep in local_episodes), device=device, dtype=torch.long)
        global_env_steps = accelerator.reduce(local_env_steps, reduction="sum").item()
        total_env_steps += int(global_env_steps)
        running_sr.append(success_rate)
        if len(running_sr) > 20:
            running_sr.pop(0)
        running_sr_avg = np.mean(running_sr)
        best_sr = max(best_sr, success_rate)

        if is_main:
            logger.info(f"[iter {iteration}] SR={success_rate:.2f} (best={best_sr:.2f}, "
                         f"running_avg={running_sr_avg:.2f}) reward={mean_reward:.2f} "
                         f"steps={mean_steps:.1f} ({len(global_rewards)} total episodes)")

        # ── PPO update on local episodes (each rank independently) ────
        # Since the network is tiny and identical across ranks, each rank
        # computes gradients on its own local episodes. We average gradients
        # across ranks for consistency.
        if is_main:
            logger.info(f"[iter {iteration}] PPO update ({args.ppo_epochs} epochs)...")
        actor.train()
        critic.train()
        if args.lr_encoder > 0:
            enc_dec.train()

        epoch_stats = []
        for ppo_epoch in range(args.ppo_epochs):
            optimizer.zero_grad()
            loss, stats = action_token_ppo_loss(
                encoder=enc_dec,
                actor=actor,
                critic=critic,
                episodes=local_episodes,
                gamma=args.gamma,
                gae_lambda=args.gae_lambda,
                clip_eps=args.clip_eps,
                vf_coef=args.vf_coef,
                recon_loss_coef=args.recon_loss_coef,
                frozen_vla=frozen_vla,
                device=str(device),
            )
            loss.backward()
            # Average gradients across ranks
            for p in list(actor.parameters()) + list(critic.parameters()):
                if p.grad is not None:
                    torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.AVG)
            if args.lr_encoder > 0:
                for p in enc_dec.parameters():
                    if p.grad is not None:
                        torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.AVG)
            if args.max_grad_norm > 0:
                all_params = list(actor.parameters()) + list(critic.parameters())
                if args.lr_encoder > 0:
                    all_params += list(enc_dec.parameters())
                torch.nn.utils.clip_grad_norm_(all_params, args.max_grad_norm)
            optimizer.step()
            epoch_stats.append(stats)

        # ── Deterministic Eval (distributed across all ranks) ──────
        eval_sr = None
        eval_result = None
        do_eval = (args.eval_interval > 0
                   and iteration % args.eval_interval == 0)
        if do_eval:
            if is_main:
                logger.info(f"[iter {iteration}] Running distributed eval "
                             f"({args.eval_n_episodes} episodes across {world_size} GPUs)...")
            eval_video_dir = str(video_dir / f"eval_iter_{iteration:05d}") if save_video else None
            eval_result = _eval_distributed(
                accelerator=accelerator,
                frozen_vla=frozen_vla,
                encoder=enc_dec,
                actor=actor,
                suite_name=args.suite,
                task_id=task_id,
                action_norm_stats=action_norm_stats,
                max_steps=max_steps,
                chunk_len=chunk_len,
                n_episodes=args.eval_n_episodes,
                num_steps_wait=args.num_steps_wait,
                seed=args.seed,
                device=str(device),
                video_dir=eval_video_dir,
            )
            if is_main and eval_result:
                eval_sr = eval_result["eval_sr"]
                best_eval_sr = max(best_eval_sr, eval_sr)
                logger.info(f"  [eval] SR={eval_sr:.2%} (best_eval={best_eval_sr:.2%})")
                for sid, sr in eval_result["per_state"].items():
                    logger.info(f"    state {sid}: {sr:.2%}")

        # ── Logging (main rank only) ──────────────────────────
        if iteration % args.log_interval == 0 and is_main:
            avg = lambda k: float(np.mean([s[k] for s in epoch_stats if k in s]))
            log_entry = {
                "iter": iteration,
                "total_env_steps": total_env_steps,
                "success_rate": success_rate,
                "best_success_rate": best_sr,
                "running_avg_sr": running_sr_avg,
                "mean_reward": mean_reward,
                "loss": avg("loss"),
                "pg_loss": avg("pg_loss"),
                "vf_loss": avg("vf_loss"),
                "ratio_mean": avg("ratio_mean"),
                "clip_frac": avg("clip_frac"),
                "advantage_mean": avg("advantage_mean"),
                "value_mean": avg("value_mean"),
                "n_steps": avg("n_steps"),
            }
            if eval_sr is not None:
                log_entry["eval_sr"] = eval_sr
                log_entry["best_eval_sr"] = best_eval_sr
            metrics_history.append(log_entry)
            logger.info(f"  loss={log_entry['loss']:.4f} pg={log_entry['pg_loss']:.4f} "
                         f"vf={log_entry['vf_loss']:.4f} ratio={log_entry['ratio_mean']:.3f} "
                         f"clip_frac={log_entry['clip_frac']:.3f} "
                         f"total_env_steps={total_env_steps}")

            if args.use_wandb:
                wandb_log = {
                    "rollout/success_rate": success_rate,
                    "rollout/best_success_rate": best_sr,
                    "rollout/running_avg_sr": running_sr_avg,
                    "rollout/mean_reward": mean_reward,
                    "rollout/total_env_steps": total_env_steps,
                    "rollout/iter_env_steps": int(global_env_steps),
                    "train/loss": log_entry["loss"],
                    "train/pg_loss": log_entry["pg_loss"],
                    "train/vf_loss": log_entry["vf_loss"],
                    "train/ratio_mean": log_entry["ratio_mean"],
                    "train/clip_frac": log_entry["clip_frac"],
                    "train/advantage_mean": log_entry["advantage_mean"],
                    "train/value_mean": log_entry["value_mean"],
                    "train/n_steps": log_entry["n_steps"],
                }
                if eval_sr is not None:
                    wandb_log["eval/success_rate"] = eval_sr
                    wandb_log["eval/best_success_rate"] = best_eval_sr
                    for sid, sr in eval_result["per_state"].items():
                        wandb_log[f"eval/state_{sid:02d}"] = sr
                for ep in sorted(local_episodes, key=lambda e: -e.success):
                    if ep.video_path and os.path.exists(ep.video_path):
                        status = "success" if ep.success else "fail"
                        wandb_log[f"video/{status}"] = wandb.Video(
                            ep.video_path, fps=10, format="mp4")
                        break
                wandb.log(wandb_log, step=iteration)

        # ── Checkpoint (main rank only) ──────────────────────
        if iteration % args.save_interval == 0 and is_main:
            save_rlt_checkpoint(enc_dec, actor, critic,
                                iteration, args.output_dir, phase="rl")

        # Sync all ranks before next iteration
        accelerator.wait_for_everyone()

    # Final save
    if is_main:
        save_rlt_checkpoint(enc_dec, actor, critic,
                            args.max_iter, args.output_dir, phase="rl")
        metrics_path = Path(args.output_dir) / "metrics.json"
        metrics_path.parent.mkdir(parents=True, exist_ok=True)
        with open(metrics_path, "w") as f:
            json.dump(metrics_history, f, indent=2)
        logger.info(f"Done. Metrics -> {metrics_path}")

    if args.use_wandb and is_main:
        wandb.finish()

Off-policy RL

train_rl_offpolicy

Phase 2 (production): off-policy TD3 with split rollout/training GPUs.

Architecture (single process, no Accelerate):

- Rollout GPUs: each loads a frozen VLA copy, collects episodes in parallel
- Train GPU: runs actor-critic TD updates (backward pass)
- Replay buffer: centralized on CPU, fed by all rollout GPUs
- Eval: distributed across rollout GPUs for speed
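A centralized CPU replay buffer of the kind described can be very small; an illustrative sketch (not the repo's `common/` replay buffer class):

```python
import random
from collections import deque

class ReplayBuffer:
    """Minimal FIFO replay buffer held on CPU; transitions are plain objects."""

    def __init__(self, capacity: int = 100_000):
        self._buf = deque(maxlen=capacity)  # oldest entry evicted at capacity

    def add(self, transition) -> None:
        self._buf.append(transition)

    def sample(self, batch_size: int):
        # Uniform sampling without replacement from the stored transitions
        return random.sample(list(self._buf), batch_size)

    def __len__(self) -> int:
        return len(self._buf)
```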

run_rl_offpolicy
run_rl_offpolicy(args)

Phase 2 off-policy variant: split rollout/training GPUs.

Lets you freely assign GPUs, e.g.: --rollout_gpus 0,1,2,3,4 --train_gpu 5 Train GPU can overlap with rollout GPUs (rollout and training are sequential).
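The TD3-style target networks this phase maintains are typically kept in sync by Polyak averaging; a minimal sketch of that soft update (illustrative, not the repo's exact code):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def polyak_update(target: nn.Module, source: nn.Module, tau: float = 0.005) -> None:
    # Soft target update: target <- (1 - tau) * target + tau * source
    for tp, sp in zip(target.parameters(), source.parameters()):
        tp.mul_(1.0 - tau).add_(sp, alpha=tau)
```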

Source code in AlphaBrain/training/reinforcement_learning/trainers/train_rl_offpolicy.py
def run_rl_offpolicy(args):
    """Phase 2 off-policy variant: split rollout/training GPUs.

    Lets you freely assign GPUs, e.g.:
      --rollout_gpus 0,1,2,3,4  --train_gpu 5
    Train GPU can overlap with rollout GPUs (rollout and training are sequential).
    """
    set_seed(args.seed)

    # ── Parse GPU config ─────────────────────────────────
    if args.rollout_gpus is not None:
        rollout_gpu_ids = [int(g.strip()) for g in args.rollout_gpus.split(",")]
    else:
        rollout_gpu_ids = list(range(torch.cuda.device_count()))

    train_gpu_id = args.train_gpu if args.train_gpu is not None else rollout_gpu_ids[0]
    train_device = f"cuda:{train_gpu_id}"
    n_rollout_gpus = len(rollout_gpu_ids)

    logger.info(f"=== Off-Policy TD Mode (Split GPU) ===")
    logger.info(f"  Rollout GPUs: {rollout_gpu_ids}")
    logger.info(f"  Train GPU: {train_gpu_id}")

    # ── Load frozen VLA on each rollout GPU ──────────────
    vla_copies = {}
    for i, gpu_id in enumerate(rollout_gpu_ids):
        device = f"cuda:{gpu_id}"
        logger.info(f"  Loading frozen VLA on GPU {gpu_id} ({i+1}/{n_rollout_gpus})...")
        vla = BaseFramework.from_pretrained(args.ckpt_path)
        vla = vla.to(torch.bfloat16).to(device).eval()
        for p in vla.parameters():
            p.requires_grad_(False)
        vla_copies[gpu_id] = vla

    # Load VLA on train GPU (for eval, and optionally for fine-tuning)
    if args.finetune_vla:
        # Full fine-tune: need a SEPARATE trainable VLA on train GPU
        # (the rollout loop may have already loaded a frozen copy — replace it)
        logger.info(f"  Loading TRAINABLE VLA on train GPU {train_gpu_id} (full fine-tune)...")
        vla = BaseFramework.from_pretrained(args.ckpt_path)
        vla = vla.to(torch.bfloat16).to(train_device).train()
        if hasattr(vla, "qwen_vl_interface") and hasattr(vla.qwen_vl_interface, "model"):
            vla.qwen_vl_interface.model.gradient_checkpointing_enable()
        vla_copies[train_gpu_id] = vla
        logger.info(f"  Train GPU VLA: TRAINABLE ({sum(p.numel() for p in vla.parameters()) / 1e9:.2f}B params, gradient_checkpointing)")
    elif train_gpu_id not in vla_copies:
        logger.info(f"  Loading frozen VLA on train GPU {train_gpu_id}...")
        vla = BaseFramework.from_pretrained(args.ckpt_path)
        vla = vla.to(torch.bfloat16).to(train_device).eval()
        for p in vla.parameters():
            p.requires_grad_(False)
        vla_copies[train_gpu_id] = vla

    # Get model config from any VLA copy
    ref_vla = vla_copies[rollout_gpu_ids[0]]
    hidden_dim = ref_vla.qwen_vl_interface.model.config.hidden_size
    chunk_len = ref_vla.chunk_len
    action_dim = ref_vla.config.framework.action_model.action_dim
    _norm_stats = ref_vla.norm_stats
    _unnorm_key = next(iter(_norm_stats.keys()))
    action_norm_stats = _norm_stats[_unnorm_key]["action"]

    # Actor chunk length: paper uses C < H (e.g. VLA H=50, actor C=10)
    # For LIBERO: VLA chunk=8, actor chunk=4 (re-plan every 4 steps)
    actor_chunk_len = args.actor_chunk_len if args.actor_chunk_len else chunk_len
    logger.info(f"VLA chunk_len={chunk_len}, actor_chunk_len={actor_chunk_len}")

    suite_info = get_suite_info(args.suite)
    n_tasks = suite_info["n_tasks"]
    max_steps = MAX_STEPS[args.suite]

    # ── Handle --task_ids: subset of tasks treated like --all_tasks ──
    if args.task_ids is not None:
        selected_task_ids = [int(x) for x in args.task_ids.split(",")]
        logger.info(f"--task_ids={args.task_ids}: training on tasks {selected_task_ids}")
        args.all_tasks = True
        # Remap: override n_tasks and suite_info so all_tasks branch iterates
        # only over the selected tasks. We patch range(n_tasks) → selected_task_ids
        # by storing the list and overriding the iteration below.
        args._selected_task_ids = selected_task_ids
    else:
        args._selected_task_ids = None

    # ── Create trainable modules on train_gpu ────────────
    enc_dec = ActionTokenEncoderDecoder(
        input_dim=hidden_dim,
        bottleneck_dim=args.bottleneck_dim,
        chunk_len=chunk_len,
        num_heads=args.encoder_heads,
        encoder_layers=args.encoder_layers,
        decoder_layers=args.encoder_layers,
    ).to(train_device)

    if args.encoder_path:
        logger.info(f"  Loading pretrained encoder from {args.encoder_path}")
        state = torch.load(args.encoder_path, map_location=train_device)
        enc_dec.load_state_dict(state)

    if args.finetune_vla:
        # Full fine-tune: encoder trainable (updated via VLA fine-tune step)
        enc_dec.train()
        logger.info("  Encoder: TRAINABLE (full fine-tune mode)")
    else:
        # Freeze encoder for off-policy (buffer rl_tokens must stay valid)
        enc_dec.eval()
        for p in enc_dec.parameters():
            p.requires_grad_(False)

    actor = ActionTokenActor(
        bottleneck_dim=args.bottleneck_dim,
        action_dim=action_dim,
        chunk_len=actor_chunk_len,
        hidden_dim=args.actor_hidden_dim,
        ref_dropout=args.ref_dropout,
        fixed_std=args.fixed_std,
        prop_dim=8,  # paper: x = (z_rl, s_p) where s_p = eef_pos(3)+axisangle(3)+gripper(2)
    ).to(train_device)

    # Paper: Q(s, a) twin-Q critic (TD3 style)
    q_critic = ActionTokenQCritic(
        bottleneck_dim=args.bottleneck_dim,
        action_dim=action_dim,
        chunk_len=actor_chunk_len,
        hidden_dim=args.critic_hidden_dim,
        prop_dim=8,  # paper: x = (z_rl, s_p)
    ).to(train_device)

    # Target critic (Polyak-averaged copy of twin-Q)
    target_q_critic = copy.deepcopy(q_critic).to(train_device)
    target_q_critic.eval()
    for p in target_q_critic.parameters():
        p.requires_grad_(False)

    # Target actor (Polyak-averaged copy — TD3 uses this for next_action in target Q)
    target_actor = copy.deepcopy(actor).to(train_device)
    target_actor.eval()
    for p in target_actor.parameters():
        p.requires_grad_(False)

    # Also keep a lightweight V(s) critic for rollout (value estimate logging)
    # The rollout only needs a dummy critic for the episode data structure
    dummy_critic = ActionTokenCritic(
        bottleneck_dim=args.bottleneck_dim,
        hidden_dim=64,  # tiny, just for rollout value logging
    ).to(train_device)

    # ── Rollout module copies (one per rollout GPU, tiny ~9M) ──
    rollout_modules = {}
    for gpu_id in rollout_gpu_ids:
        device = f"cuda:{gpu_id}"
        r_enc = copy.deepcopy(enc_dec).to(device).eval()
        r_actor = copy.deepcopy(actor).to(device).eval()
        r_critic = copy.deepcopy(dummy_critic).to(device).eval()
        rollout_modules[gpu_id] = (r_enc, r_actor, r_critic)

    # ── Eval module copies — separate from rollout to allow async eval ──
    # (rollout_modules are read by the rollout thread; eval_modules are read by
    #  the async eval thread. Keeping them separate avoids races.)
    eval_modules = {}
    for gpu_id in rollout_gpu_ids:
        device = f"cuda:{gpu_id}"
        e_enc = copy.deepcopy(enc_dec).to(device).eval()
        e_actor = copy.deepcopy(actor).to(device).eval()
        for p in e_enc.parameters():
            p.requires_grad_(False)
        for p in e_actor.parameters():
            p.requires_grad_(False)
        eval_modules[gpu_id] = (e_enc, e_actor)

    # ── Per-GPU rollout infrastructure ──
    if args.use_steplock:
        # Step-lock mode: persistent env pools (no BatchInferenceServer needed)
        from AlphaBrain.training.reinforcement_learning.envs.persistent_env_pool import PersistentEnvPool
        from AlphaBrain.training.reinforcement_learning.algos.RLActionToken.action_token_rollout_fast import (
            action_token_collect_group_steplock,
            action_token_collect_multitask_steplock,
        )
        rollout_servers = {}  # not used in steplock mode
        rollout_env_pools = {}
        # Compute tasks per GPU for env pool sizing
        gpu_task_map = {}
        if args.all_tasks:
            task_list_all = args._selected_task_ids if args._selected_task_ids else list(range(n_tasks))
            for t_idx in task_list_all:
                gid = rollout_gpu_ids[t_idx % n_rollout_gpus]
                gpu_task_map.setdefault(gid, []).append(t_idx)
        max_tasks_per_gpu = max(len(v) for v in gpu_task_map.values()) if gpu_task_map else 1
        # Pool sizing: num_envs_per_task × tasks_on_this_gpu
        # Decouples parallelism from G_per_task. If G_per_task > num_envs_per_task,
        # the rollout chunks into ceil(G_per_task / num_envs_per_task) sequential passes.
        envs_per_gpu = args.num_envs_per_task * max_tasks_per_gpu
        # Map logical GPU ID → physical GPU ID for MuJoCo EGL rendering.
        # If CUDA_VISIBLE_DEVICES is unset, fall back to identity mapping (up to 8 GPUs).
        cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        physical_gpus = [int(x) for x in cuda_devices.split(",") if x.strip()] if cuda_devices else list(range(8))
        for gpu_id in rollout_gpu_ids:
            physical_gpu = physical_gpus[gpu_id] if gpu_id < len(physical_gpus) else gpu_id
            pool = PersistentEnvPool(
                num_envs=envs_per_gpu,
                libero_python=os.environ.get("LIBERO_PYTHON"),
                egl_gpu_id=physical_gpu,
            )
            rollout_env_pools[gpu_id] = pool
        n_passes_per_iter = max(1, (args.G_per_task + args.num_envs_per_task - 1) // args.num_envs_per_task)
        logger.info(f"  Step-lock mode: {len(rollout_gpu_ids)} GPU × {envs_per_gpu} persistent envs "
                     f"({max_tasks_per_gpu} tasks/GPU × {args.num_envs_per_task} envs/task, "
                     f"G_per_task={args.G_per_task} → {n_passes_per_iter} passes/iter)")

        # No pre-warm needed — parallel reset in rollout handles MuJoCo init.
    else:
        # Async mode: BatchInferenceServer per GPU
        from AlphaBrain.training.reinforcement_learning.algos.RLActionToken.action_token_trainer import BatchInferenceServer
        rollout_servers = {}
        rollout_env_pools = {}  # not used in async mode
        for gpu_id in rollout_gpu_ids:
            r_enc, r_actor, r_critic = rollout_modules[gpu_id]
            server = BatchInferenceServer(
                frozen_vla=vla_copies[gpu_id],
                encoder=r_enc,
                actor=r_actor,
                critic=r_critic,
                device=f"cuda:{gpu_id}",
                max_batch_size=args.num_envs * 4,
                actor_chunk_len=actor_chunk_len if actor_chunk_len != chunk_len else None,
            ).start()
            rollout_servers[gpu_id] = server
        logger.info(f"  Started {len(rollout_gpu_ids)} BatchInferenceServer(s) on GPUs "
                    f"{rollout_gpu_ids} (max_batch={args.num_envs * 4} each)")
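The pass-chunking arithmetic used in step-lock mode (ceil division to decouple episodes per iteration from env-pool parallelism) can be checked in isolation. The numbers below are illustrative stand-ins for `args.G_per_task` and `args.num_envs_per_task`:

```python
# Standalone check of the ceil-division pass chunking (illustrative values).
G_per_task, num_envs_per_task = 10, 4
n_passes = max(1, (G_per_task + num_envs_per_task - 1) // num_envs_per_task)  # ceil(10/4)
G_per_pass = min(G_per_task, num_envs_per_task)  # envs actually run each pass
print(n_passes, G_per_pass)  # 3 4
```

So with 4 persistent envs per task and a target of 10 episodes, each iteration runs 3 sequential passes of up to 4 envs each, with per-pass seeds drawing fresh initial states and noise.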

    enc_params = sum(p.numel() for p in enc_dec.parameters())
    actor_params = sum(p.numel() for p in actor.parameters())
    critic_params = sum(p.numel() for p in q_critic.parameters())
    vla_params = sum(p.numel() for p in ref_vla.parameters())
    logger.info(f"  Frozen VLA: {vla_params / 1e9:.2f}B params × {n_rollout_gpus} GPUs")
    logger.info(f"  Frozen encoder: {enc_params / 1e6:.2f}M params")
    logger.info(f"  Trainable: actor={actor_params / 1e6:.2f}M, critic={critic_params / 1e6:.2f}M")
    logger.info(f"  Replay buffer: capacity={args.buffer_capacity}, warmup={args.buffer_warmup}")
    logger.info(f"  TD updates: {args.td_updates_per_iter}/iter, batch={args.td_batch_size}, tau={args.tau}")
    logger.info(f"  Rollout: {n_rollout_gpus} GPUs × {args.num_envs} envs × "
                f"{args.G} episodes/GPU = {n_rollout_gpus * args.G} episodes/iter")

    # ── Separate optimizers for actor and critic (TD3 pattern) ──
    optimizer_critic = torch.optim.AdamW(
        q_critic.parameters(), lr=args.lr_critic,
        betas=(0.9, 0.95), weight_decay=1e-8)
    optimizer_actor = torch.optim.AdamW(
        actor.parameters(), lr=args.lr_actor,
        betas=(0.9, 0.95), weight_decay=1e-8)

    # ── VLA fine-tune optimizer (only when --finetune_vla) ──
    optimizer_vla = None
    if args.finetune_vla:
        train_vla = vla_copies[train_gpu_id]
        # Distinct names: avoid shadowing the param-count ints (vla_params / enc_params) above
        vla_trainable = [p for p in train_vla.parameters() if p.requires_grad]
        enc_trainable = [p for p in enc_dec.parameters() if p.requires_grad]
        optimizer_vla = torch.optim.AdamW(
            [{"params": vla_trainable, "lr": args.lr_vla},
             {"params": enc_trainable, "lr": args.lr_vla * 2}],
            betas=(0.9, 0.95), weight_decay=1e-8)
        n_vla_trainable = sum(p.numel() for p in vla_trainable)
        n_enc_trainable = sum(p.numel() for p in enc_trainable)
        logger.info(f"  VLA optimizer: {n_vla_trainable / 1e9:.2f}B VLA + "
                     f"{n_enc_trainable / 1e6:.2f}M encoder, lr={args.lr_vla}")

    # ── Replay buffer (centralized) ──────────────────────
    replay_buffer = ReplayBuffer(capacity=args.buffer_capacity)

    # ── WandB ─────────────────────────────────────────────
    if args.use_wandb:
        run_name = args.run_name or f"action_token_offpolicy_{args.suite}_task{args.task_id}"
        wandb.init(project=args.wandb_project, name=run_name,
                   config={**vars(args), "chunk_len": chunk_len,
                           "hidden_dim": hidden_dim, "action_dim": action_dim,
                           "n_rollout_gpus": n_rollout_gpus, "mode": "offpolicy_td"})

    video_dir = Path(args.output_dir) / "videos"
    metrics_history = []
    best_sr = 0.0
    best_eval_sr = 0.0
    running_sr = []
    total_env_steps = 0  # cumulative environment steps (sample steps)

    # ── Async rollout helper ────────────────────────────
    # Rollout runs on background threads (rollout GPUs) while TD updates
    # run on the main thread (train GPU). This matches the PI paper's
    # asynchronous rollout + learning design.

    buffer_lock = threading.Lock()
    rollout_stats_queue = queue.Queue()   # (episodes, iteration)
    _stop_rollout = threading.Event()
    _weight_sync_lock = threading.Lock()  # serializes weight copies (rollout sync vs eval sync)

    def _sync_rollout_weights():
        """Copy the latest actor/encoder weights to every rollout GPU copy."""
        with _weight_sync_lock:
            enc_state_cpu = {k: v.cpu() for k, v in enc_dec.state_dict().items()}
            actor_state_cpu = {k: v.cpu() for k, v in actor.state_dict().items()}
            dummy_critic_state_cpu = {k: v.cpu() for k, v in dummy_critic.state_dict().items()}
            for gpu_id in rollout_gpu_ids:
                r_enc, r_actor, r_critic = rollout_modules[gpu_id]
                dev = f"cuda:{gpu_id}"
                r_enc.load_state_dict({k: v.to(dev) for k, v in enc_state_cpu.items()})
                r_actor.load_state_dict({k: v.to(dev) for k, v in actor_state_cpu.items()})
                r_critic.load_state_dict({k: v.to(dev) for k, v in dummy_critic_state_cpu.items()})
        return enc_state_cpu, actor_state_cpu

    _steplock_warmup = [True]  # shared flag: rollout thread reads, main thread sets False
    _rollout_go = threading.Event()  # clear = paused, set = running
    _rollout_go.set()  # start unpaused

    # ── Async eval state ──
    # Eval runs in a background thread so train + rollout don't block.
    # Results arrive in _eval_results_queue, drained each main iteration.
    _eval_results_queue = queue.Queue()
    _eval_thread_holder = [None]  # mutable holder so closures can update
    _eval_lock = threading.Lock()  # ensures only one eval at a time

    def _sync_eval_weights():
        """Copy latest train weights to eval modules (fast, ~10ms)."""
        with _weight_sync_lock:
            enc_state_cpu = {k: v.cpu() for k, v in enc_dec.state_dict().items()}
            actor_state_cpu = {k: v.cpu() for k, v in actor.state_dict().items()}
            for gpu_id in rollout_gpu_ids:
                e_enc, e_actor = eval_modules[gpu_id]
                dev = f"cuda:{gpu_id}"
                e_enc.load_state_dict({k: v.to(dev) for k, v in enc_state_cpu.items()})
                e_actor.load_state_dict({k: v.to(dev) for k, v in actor_state_cpu.items()})

    def _run_eval_inline(iteration, save_video):
        """Synchronous eval body — same logic as before but uses eval_modules.

        Returns dict with eval_sr, eval_result, per_task_eval_sr.
        Does NOT mutate outer state (best_eval_sr is updated in main thread).
        """
        per_task_eval_sr_local = {}
        eval_result_local = None

        if args.all_tasks:
            # Multi-task eval
            eval_task_list = args._selected_task_ids if args._selected_task_ids else list(range(n_tasks))
            n_eval_tasks = len(eval_task_list)
            eval_n_per_task = max(1, args.eval_n_episodes // n_eval_tasks)
            total_eval_eps = eval_n_per_task * n_eval_tasks
            logger.info(f"[ASYNC EVAL @ iter {iteration}] Multi-task: {eval_n_per_task} eps/task × "
                         f"{n_eval_tasks} tasks = {total_eval_eps} episodes")

            eval_gpu_jobs = {gpu_id: [] for gpu_id in rollout_gpu_ids}
            job_idx = 0
            for task_id_eval in eval_task_list:
                for ep_idx in range(eval_n_per_task):
                    gpu_id = rollout_gpu_ids[job_idx % n_rollout_gpus]
                    eval_gpu_jobs[gpu_id].append((task_id_eval, ep_idx))
                    job_idx += 1

            eval_video_dir = (str(video_dir / f"eval_iter_{iteration:05d}") if save_video else None)

            all_eval_results = []
            with ThreadPoolExecutor(max_workers=n_rollout_gpus * 2) as pool:
                futures = {}
                for gpu_id, jobs in eval_gpu_jobs.items():
                    if not jobs:
                        continue
                    task_groups = defaultdict(list)
                    for tid, eidx in jobs:
                        task_groups[tid].append(eidx)
                    for tid, ep_indices in task_groups.items():
                        e_enc, e_actor = eval_modules[gpu_id]
                        task_vid_dir = (os.path.join(eval_video_dir, f"task_{tid}") if eval_video_dir else None)
                        fut = pool.submit(
                            _eval_deterministic_local,
                            frozen_vla=vla_copies[gpu_id],
                            encoder=e_enc,
                            actor=e_actor,
                            suite_name=args.suite,
                            task_id=tid,
                            action_norm_stats=action_norm_stats,
                            max_steps=max_steps,
                            chunk_len=actor_chunk_len,
                            episode_indices=ep_indices,
                            num_steps_wait=args.num_steps_wait,
                            seed=42,
                            device=f"cuda:{gpu_id}",
                            rank=gpu_id,
                            video_dir=task_vid_dir,
                        )
                        futures[fut] = (gpu_id, tid)
                for fut in as_completed(futures):
                    gpu_id, tid = futures[fut]
                    results = fut.result()
                    for ep_idx, state_idx, success in results:
                        all_eval_results.append((tid, ep_idx, state_idx, success))

            task_successes_map = defaultdict(list)
            for tid, _, _, success in all_eval_results:
                task_successes_map[tid].append(success)
            for tid in sorted(task_successes_map.keys()):
                v = task_successes_map[tid]
                task_sr = float(np.mean(v))
                per_task_eval_sr_local[tid] = task_sr
                logger.info(f"  [async eval] task {tid} ({suite_info['task_names'][tid][:40]}): "
                             f"SR={task_sr:.2%} ({sum(v)}/{len(v)})")

            all_success = [s for _, _, _, s in all_eval_results]
            eval_sr_local = float(np.mean(all_success)) if all_success else 0.0
            eval_result_local = {
                "eval_sr": eval_sr_local,
                "per_task": per_task_eval_sr_local,
                "n_episodes": len(all_success),
            }
        else:
            # Single-task eval
            task_id_eval = args.task_id if args.task_id >= 0 else 0
            n_eval = args.eval_n_episodes
            logger.info(f"[ASYNC EVAL @ iter {iteration}] Single-task: {n_eval} episodes")

            eval_video_dir = str(video_dir / f"eval_iter_{iteration:05d}") if save_video else None
            eval_assignments = {gpu_id: [] for gpu_id in rollout_gpu_ids}
            for ep_idx in range(n_eval):
                gpu_id = rollout_gpu_ids[ep_idx % n_rollout_gpus]
                eval_assignments[gpu_id].append(ep_idx)

            all_eval_results = []
            with ThreadPoolExecutor(max_workers=n_rollout_gpus) as pool:
                futures = {}
                for gpu_id, ep_indices in eval_assignments.items():
                    if not ep_indices:
                        continue
                    e_enc, e_actor = eval_modules[gpu_id]
                    fut = pool.submit(
                        _eval_deterministic_local,
                        frozen_vla=vla_copies[gpu_id],
                        encoder=e_enc,
                        actor=e_actor,
                        suite_name=args.suite,
                        task_id=task_id_eval,
                        action_norm_stats=action_norm_stats,
                        max_steps=max_steps,
                        chunk_len=actor_chunk_len,
                        episode_indices=ep_indices,
                        num_steps_wait=args.num_steps_wait,
                        seed=42,
                        device=f"cuda:{gpu_id}",
                        rank=gpu_id,
                        video_dir=eval_video_dir,
                    )
                    futures[fut] = gpu_id
                for fut in as_completed(futures):
                    all_eval_results.extend(fut.result())

            per_state = defaultdict(list)
            all_success = []
            for ep_idx, state_idx, success in all_eval_results:
                per_state[state_idx].append(success)
                all_success.append(success)

            eval_sr_local = float(np.mean(all_success)) if all_success else 0.0
            per_state_sr = {sid: float(np.mean(v)) for sid, v in sorted(per_state.items())}
            eval_result_local = {
                "eval_sr": eval_sr_local,
                "per_state": per_state_sr,
                "n_episodes": len(all_success),
            }

        return {
            "iteration": iteration,
            "eval_sr": eval_sr_local,
            "eval_result": eval_result_local,
            "per_task_eval_sr": per_task_eval_sr_local,
        }

    def _async_eval_fn(iteration, save_video):
        """Background thread target: run eval, push result to queue."""
        try:
            result = _run_eval_inline(iteration, save_video)
            _eval_results_queue.put(result)
            logger.info(f"[ASYNC EVAL @ iter {iteration}] done, SR={result['eval_sr']:.2%}")
        except Exception:
            logger.exception(f"[ASYNC EVAL @ iter {iteration}] crashed")
        finally:
            _eval_thread_holder[0] = None

    def _rollout_thread_fn(start_iter, max_iter_val):
        """Background thread: continuously collects episodes and pushes to buffer."""
        try:
          for it in range(start_iter, max_iter_val + 1):
            if _stop_rollout.is_set():
                break
            # Wait if paused (during eval)
            _rollout_go.wait()  # blocks until set

            # Build task list
            if args.all_tasks:
                task_list_all = args._selected_task_ids if args._selected_task_ids else list(range(n_tasks))
                gpu_task_assignments = {gpu_id: [] for gpu_id in rollout_gpu_ids}
                for t_idx in task_list_all:
                    gpu_id = rollout_gpu_ids[t_idx % n_rollout_gpus]
                    gpu_task_assignments[gpu_id].append(t_idx)
            else:
                task_id = args.task_id if args.task_id >= 0 else random.randint(0, n_tasks - 1)
                gpu_task_assignments = {gpu_id: [task_id] for gpu_id in rollout_gpu_ids}

            if args.use_steplock:
                # Step-lock: use plain threads (no nested ThreadPoolExecutor).
                # One thread per GPU, each runs action_token_collect_multitask_steplock.
                #
                # Auto-chunk: if G > num_envs, run ceil(G/num_envs) sequential passes per iter.
                # Each pass uses different seeds → different states/noise sampled.
                # This decouples parallelism (num_envs) from total ep/iter (G).
                n_passes = max(1, (args.G_per_task + args.num_envs_per_task - 1) // args.num_envs_per_task)
                G_per_pass = min(args.G_per_task, args.num_envs_per_task)

                all_eps = []
                per_task_sr = {}
                for pass_idx in range(n_passes):
                    gpu_results = {}
                    gpu_threads = []

                    def _run_gpu(gpu_id, task_list, pass_idx=pass_idx):
                        r_enc, r_actor, r_critic = rollout_modules[gpu_id]
                        group_seed = args.seed + it * 1000 + gpu_id * 100 + pass_idx * 50000
                        unique_group_idx = (it * n_passes + pass_idx) * n_rollout_gpus + gpu_id
                        if len(task_list) > 1:
                            eps = action_token_collect_multitask_steplock(
                                env_pool=rollout_env_pools[gpu_id],
                                frozen_vla=vla_copies[gpu_id],
                                encoder=r_enc, actor=r_actor, critic=r_critic,
                                suite_name=args.suite, task_ids=task_list,
                                n_initial_states=50, action_norm_stats=action_norm_stats,
                                max_steps=max_steps, chunk_len=chunk_len,
                                G_per_task=G_per_pass, seed=group_seed,
                                num_steps_wait=args.num_steps_wait,
                                device=f"cuda:{gpu_id}",
                                group_idx=unique_group_idx,
                                store_images=args.finetune_vla,
                                group_size=args.group_size, reward_coef=args.reward_coef,
                                actor_chunk_len=actor_chunk_len if actor_chunk_len != chunk_len else None,
                                warmup_mode=_steplock_warmup[0],
                            )
                        else:
                            eps = action_token_collect_group_steplock(
                                env_pool=rollout_env_pools[gpu_id],
                                frozen_vla=vla_copies[gpu_id],
                                encoder=r_enc, actor=r_actor, critic=r_critic,
                                suite_name=args.suite, task_id=task_list[0],
                                n_initial_states=50, action_norm_stats=action_norm_stats,
                                max_steps=max_steps, chunk_len=chunk_len, G=G_per_pass,
                                seed=group_seed, num_steps_wait=args.num_steps_wait,
                                device=f"cuda:{gpu_id}",
                                group_idx=unique_group_idx,
                                store_images=args.finetune_vla,
                                group_size=args.group_size, reward_coef=args.reward_coef,
                                actor_chunk_len=actor_chunk_len if actor_chunk_len != chunk_len else None,
                                warmup_mode=_steplock_warmup[0],
                            )
                        gpu_results[gpu_id] = (task_list, eps)

                    for gpu_id, task_list in gpu_task_assignments.items():
                        t = threading.Thread(target=_run_gpu, args=(gpu_id, task_list))
                        t.start()
                        gpu_threads.append(t)
                    for t in gpu_threads:
                        t.join()

                    for gpu_id, (tid_list, eps) in gpu_results.items():
                        all_eps.extend(eps)
                        for ep in eps:
                            per_task_sr.setdefault(ep.task_id, []).append(ep.success)
                        n_s = sum(1 for e in eps if e.success)
                        pass_str = f" pass {pass_idx+1}/{n_passes}" if n_passes > 1 else ""
                        logger.info(f"  [rollout iter {it}{pass_str}] GPU {gpu_id} tasks {tid_list}: "
                                    f"{len(eps)} eps, {n_s} success")
            else:
                # Async mode: use ThreadPoolExecutor
                from AlphaBrain.training.reinforcement_learning.algos.RLActionToken.action_token_trainer import action_token_collect_group
                all_eps = []
                per_task_sr = {}
                futs = {}
                with ThreadPoolExecutor(max_workers=n_rollout_gpus * 2) as rollout_pool:
                    for gpu_id, task_list in gpu_task_assignments.items():
                        r_enc, r_actor, r_critic = rollout_modules[gpu_id]
                        for tid in task_list:
                            group_seed = args.seed + it * 1000 + gpu_id * 100 + tid * 10
                            fut = rollout_pool.submit(
                                action_token_collect_group,
                                frozen_vla=vla_copies[gpu_id],
                                encoder=r_enc, actor=r_actor, critic=r_critic,
                                suite_name=args.suite, task_id=tid,
                                n_initial_states=50, action_norm_stats=action_norm_stats,
                                max_steps=max_steps, chunk_len=actor_chunk_len, G=args.G,
                                libero_python=os.environ.get("LIBERO_PYTHON"),
                                seed=group_seed, num_steps_wait=args.num_steps_wait,
                                device=f"cuda:{gpu_id}",
                                num_envs=args.num_envs,
                                group_idx=it * n_tasks * n_rollout_gpus + gpu_id * n_tasks + tid,
                                batch_server=rollout_servers.get(gpu_id),
                                store_images=args.finetune_vla,
                                group_size=args.group_size, reward_coef=args.reward_coef,
                            )
                            futs[fut] = (gpu_id, tid)
                    for fut in as_completed(futs):
                        gpu_id, tid = futs[fut]
                        eps = fut.result()
                        all_eps.extend(eps)
                        n_s = sum(1 for e in eps if e.success)
                        per_task_sr.setdefault(tid, []).extend([e.success for e in eps])
                        logger.info(f"  [rollout iter {it}] GPU {gpu_id} task {tid}: "
                                    f"{len(eps)} eps, {n_s} success")

            if args.all_tasks:
                task_sr_str = " | ".join(
                    f"t{tid}={np.mean(v):.0%}" for tid, v in sorted(per_task_sr.items()))
                logger.info(f"  [rollout iter {it}] Per-task SR: {task_sr_str}")

            with buffer_lock:
                n_pushed = push_episodes_to_buffer(
                    all_eps, replay_buffer, gamma_per_step=args.gamma)

            rollout_stats_queue.put((all_eps, it, n_pushed))
        except Exception:
            logger.exception("!!! Rollout thread CRASHED")
            rollout_stats_queue.put(None)  # poison pill: main loop breaks on None

    # ── VLA Warmup (paper Sec. V): pre-fill buffer with pure VLA rollouts ──
    if args.warmup_iters > 0:
        logger.info(f"=== VLA Warmup: {args.warmup_iters} iters of pure VLA rollout ===")
        if not args.use_steplock:
            for gpu_id, server in rollout_servers.items():
                server.warmup_mode = True

    # ── Training loop (async rollout + TD updates) ────
    # Launch rollout in background
    rollout_thread = threading.Thread(
        target=_rollout_thread_fn, args=(1, args.max_iter), daemon=True)
    rollout_thread.start()
    logger.info("Started async rollout thread")

    td_global_step = 0
    last_sync_step = 0
    sync_every_n_updates = 500  # sync weights to rollout every N TD3 updates

    from AlphaBrain.training.reinforcement_learning.algos.RLActionToken.action_token_trainer import (
        action_token_td_actor_update,
        action_token_td_critic_update,
    )

    for iteration in range(1, args.max_iter + 1):
        # ── Wait for the next rollout batch (blocking get) ────
        result = rollout_stats_queue.get()
        if result is None:
            logger.error("Rollout thread crashed (poison pill). Stopping.")
            break
        eps_batch, rollout_iter, n_pushed = result
        all_episodes = list(eps_batch)

        rewards = np.array([ep.reward for ep in all_episodes])
        success_rate = float(np.mean(rewards > 0.5)) if len(rewards) > 0 else 0.0
        iter_env_steps = sum(ep.env_steps for ep in all_episodes)
        total_env_steps += iter_env_steps
        running_sr.append(success_rate)
        if len(running_sr) > 20:
            running_sr.pop(0)
        running_sr_avg = np.mean(running_sr)
        best_sr = max(best_sr, success_rate)

        per_task_rollout_sr = {}
        if args.all_tasks:
            _task_successes = defaultdict(list)
            for ep in all_episodes:
                _task_successes[ep.task_id].append(ep.success)
            per_task_rollout_sr = {tid: float(np.mean(v))
                                   for tid, v in sorted(_task_successes.items())}
            task_sr_str = " | ".join(f"t{tid}={sr:.0%}"
                                     for tid, sr in per_task_rollout_sr.items())
        else:
            task_sr_str = ""

        logger.info(f"{'='*60}")
        logger.info(f"[iter {iteration}/{args.max_iter}] Got {len(all_episodes)} episodes "
                     f"(rollout batch {rollout_iter}) | SR={success_rate:.2f} "
                     f"(best={best_sr:.2f}, avg={running_sr_avg:.2f}) "
                     f"| buffer={len(replay_buffer)}/{args.buffer_capacity} "
                     f"| total_env_steps={total_env_steps} | td_steps={td_global_step}")
        if task_sr_str:
            logger.info(f"  Per-task rollout SR: {task_sr_str}")

        # ── VLA warmup phase ──
        td_stats_list = []  # empty during warmup; filled during TD3
        if iteration <= args.warmup_iters:
            logger.info(f"[iter {iteration}] VLA warmup ({iteration}/{args.warmup_iters}), "
                         f"buffer={len(replay_buffer)} — skipping TD updates")
            # Don't `continue` — fall through to logging + wandb so metrics are tracked
        elif iteration == args.warmup_iters + 1:
            if args.use_steplock:
                _steplock_warmup[0] = False
            else:
                for gpu_id, server in rollout_servers.items():
                    server.warmup_mode = False
            _sync_rollout_weights()
            logger.info(f"=== VLA warmup done. Buffer pre-filled with {len(replay_buffer)} "
                         f"transitions. Starting TD3 training. ===")

        # ── Async TD3 updates: run UTD×new_data updates per new data batch ──
        # Paper Algorithm 1: TD updates run EVERY step (including warmup),
        # warmup only controls which action is used for rollout (VLA vs actor).
        if replay_buffer.is_ready(min_size=args.buffer_warmup):
            actor.train()
            q_critic.train()

            n_tasks_for_balance = len(args._selected_task_ids) if args._selected_task_ids else (n_tasks if args.all_tasks else 0)
            batch_sz = min(args.td_batch_size, len(replay_buffer))

            # UTD-based: n_updates = new_transitions × utd_ratio / batch_size
            n_new_transitions = n_pushed
            n_updates = max(1, int(n_new_transitions * args.utd_ratio / batch_sz))
            n_updates = min(n_updates, args.td_updates_per_iter)  # cap
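            # Illustrative arithmetic (assumed numbers, not from a real run):
            # with 512 new transitions, utd_ratio=1.0 and a 256-sample batch,
            # n_updates = max(1, int(512 * 1.0 / 256)) = 2 gradient steps, so
            # each new transition is replayed ~utd_ratio times on average.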

            for td_step in range(n_updates):
                optimizer_critic.zero_grad()
                critic_loss, c_stats = action_token_td_critic_update(
                    actor=actor,
                    q_critic=q_critic,
                    target_q_critic=target_q_critic,
                    replay_buffer=replay_buffer,
                    batch_size=batch_sz,
                    gamma=args.gamma ** actor_chunk_len,
                    device=train_device,
                    target_noise_std=args.target_noise_std,
                    target_noise_clip=args.target_noise_clip,
                    n_tasks=n_tasks_for_balance,
                    target_actor=target_actor,
                )
                critic_loss.backward()
                if args.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(q_critic.parameters(), args.max_grad_norm)
                optimizer_critic.step()

                a_stats = {"actor_loss": 0.0, "q_actor_mean": 0.0, "bc_penalty": 0.0}
                if (td_step + 1) % args.actor_update_freq == 0:
                    optimizer_actor.zero_grad()
                    actor_loss, a_stats = action_token_td_actor_update(
                        actor=actor,
                        q_critic=q_critic,
                        replay_buffer=replay_buffer,
                        batch_size=batch_sz,
                        beta=args.beta,
                        device=train_device,
                        n_tasks=n_tasks_for_balance,
                    )
                    actor_loss.backward()
                    if args.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(actor.parameters(), args.max_grad_norm)
                    optimizer_actor.step()
                    soft_update_target(q_critic, target_q_critic, tau=args.tau)
                    soft_update_target(actor, target_actor, tau=args.tau)
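                    # Delayed-update sketch (TD3-style): the actor and both
                    # targets move only every `actor_update_freq` critic steps,
                    # and soft_update_target applies Polyak averaging,
                    # target_p <- tau * p + (1 - tau) * target_p, so the
                    # targets trail the online nets for stable bootstrapping.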

                td_stats_list.append({**c_stats, **a_stats,
                                      "td_loss": c_stats["critic_loss"] + a_stats["actor_loss"]})
                td_global_step += 1

            avg_td = np.mean([s["td_loss"] for s in td_stats_list])
            avg_critic = np.mean([s["critic_loss"] for s in td_stats_list])
            avg_actor = np.mean([s["actor_loss"] for s in td_stats_list])
            avg_bc = np.mean([s.get("bc_penalty", 0.0) for s in td_stats_list])
            avg_q = np.mean([s.get("q1_mean", 0.0) for s in td_stats_list])
            logger.info(f"[iter {iteration}] TD3: {n_updates} updates "
                        f"(UTD={n_new_transitions}×{args.utd_ratio}/{batch_sz}={n_updates}) "
                        f"td={avg_td:.4f} critic={avg_critic:.4f} actor={avg_actor:.4f} "
                        f"bc={avg_bc:.4f} q_mean={avg_q:.4f}")

            # Sync weights to rollout periodically
            if td_global_step - last_sync_step >= sync_every_n_updates:
                _sync_rollout_weights()
                last_sync_step = td_global_step
                logger.info(f"  [sync] Weights synced to rollout (td_step={td_global_step})")

            # ── VLA fine-tune step ──
            if (args.finetune_vla and optimizer_vla is not None
                    and iteration % args.vla_update_freq == 0):
                from AlphaBrain.training.reinforcement_learning.algos.RLActionToken.action_token_trainer import vla_finetune_step
                train_vla = vla_copies[train_gpu_id]
                train_vla.train()
                optimizer_vla.zero_grad()
                try:
                    vla_stats = vla_finetune_step(
                        vla=train_vla, encoder=enc_dec, actor=actor,
                        q_critic=q_critic, episodes=all_episodes,
                        beta=args.beta, device=train_device,
                        micro_batch=args.vla_micro_batch)
                    if args.max_grad_norm > 0:
                        all_vla_params = list(train_vla.parameters()) + list(enc_dec.parameters())
                        torch.nn.utils.clip_grad_norm_(all_vla_params, args.max_grad_norm)
                    optimizer_vla.step()
                    logger.info(f"[iter {iteration}] VLA fine-tune: loss={vla_stats.get('vla_loss', 0):.4f}")
                except torch.cuda.OutOfMemoryError:
                    logger.warning(f"[iter {iteration}] VLA fine-tune OOM — skipping")
                    optimizer_vla.zero_grad(set_to_none=True)
                    torch.cuda.empty_cache()
                finally:
                    train_vla.eval()
                    for ep in all_episodes:
                        for sr in ep.step_records:
                            sr.images = None
                            sr.instruction = None
                # Sync VLA weights
                vla_state_cpu = {k: v.cpu() for k, v in train_vla.state_dict().items()}
                for gpu_id in rollout_gpu_ids:
                    if gpu_id != train_gpu_id:
                        vla_copies[gpu_id].load_state_dict(
                            {k: v.to(f"cuda:{gpu_id}") for k, v in vla_state_cpu.items()})
        else:
            logger.info(f"[iter {iteration}] Buffer warmup: {len(replay_buffer)}/{args.buffer_warmup} "
                         f"(waiting for more data)")

        # ── Async eval ──
        # Eval runs in a background thread; rollout + train do not block.
        # Results arrive in _eval_results_queue and are drained below.
        eval_sr = None
        eval_result = None
        per_task_eval_sr = {}
        do_eval = (args.eval_interval > 0
                   and (iteration == 1 or iteration % args.eval_interval == 0))
        if do_eval:
            save_video = (args.save_video_interval > 0 and
                          (iteration == 1 or iteration % args.save_video_interval == 0))
            with _eval_lock:
                prev = _eval_thread_holder[0]
                if prev is None or not prev.is_alive():
                    # Sync latest train weights to eval modules (~10ms)
                    _sync_eval_weights()
                    # Spawn async eval thread (non-blocking)
                    t = threading.Thread(
                        target=_async_eval_fn,
                        args=(iteration, save_video),
                        daemon=True,
                        name=f"async_eval_{iteration}",
                    )
                    _eval_thread_holder[0] = t
                    t.start()
                    logger.info(f"[iter {iteration}] Spawned async eval (rollout/train continue)")
                else:
                    logger.warning(f"[iter {iteration}] Skip eval — previous async eval still running")

        # ── Drain async eval results (non-blocking, every iter) ──
        while True:
            try:
                eval_data = _eval_results_queue.get_nowait()
            except queue.Empty:
                break
            from_iter = eval_data["iteration"]
            eval_sr = eval_data["eval_sr"]
            eval_result = eval_data["eval_result"]
            per_task_eval_sr = eval_data["per_task_eval_sr"]
            if eval_sr > best_eval_sr:
                best_eval_sr = eval_sr
            logger.info(f"[ASYNC RESULT] from iter {from_iter}: "
                         f"SR={eval_sr:.2%} (best_eval={best_eval_sr:.2%})")
            if eval_result and "per_state" in eval_result:
                for sid, sr in eval_result["per_state"].items():
                    logger.info(f"    state {sid}: {sr:.2%}")

        # ── Logging ───────────────────────────────────
        try:
            log_entry = {
                "iter": iteration,
                "total_env_steps": total_env_steps,
                "iter_env_steps": iter_env_steps,
                "success_rate": success_rate,
                "best_success_rate": best_sr,
                "running_avg_sr": running_sr_avg,
                "mean_reward": float(np.mean(rewards)) if len(rewards) > 0 else 0.0,
                "buffer_size": len(replay_buffer),
                "n_pushed": n_pushed,
            }
            if td_stats_list:
                avg_fn = lambda k: float(np.mean([s[k] for s in td_stats_list if k in s]))
                log_entry.update({
                    "td_loss": avg_fn("td_loss"),
                    "actor_loss": avg_fn("actor_loss"),
                    "critic_loss": avg_fn("critic_loss"),
                    "q1_mean": avg_fn("q1_mean"),
                    "q2_mean": avg_fn("q2_mean"),
                    "target_mean": avg_fn("target_mean"),
                    "bc_penalty": avg_fn("bc_penalty"),
                    "q_actor_mean": avg_fn("q_actor_mean"),
                })
            if eval_sr is not None:
                log_entry["eval_sr"] = eval_sr
                log_entry["best_eval_sr"] = best_eval_sr
            if per_task_rollout_sr:
                log_entry["per_task_rollout_sr"] = per_task_rollout_sr
            if per_task_eval_sr:
                log_entry["per_task_eval_sr"] = per_task_eval_sr
            metrics_history.append(log_entry)

            if args.use_wandb:
                wandb_log = {
                    "rollout/success_rate": success_rate,
                    "rollout/best_success_rate": best_sr,
                    "rollout/running_avg_sr": running_sr_avg,
                    "rollout/mean_reward": log_entry["mean_reward"],
                    "rollout/total_env_steps": total_env_steps,
                    "rollout/iter_env_steps": iter_env_steps,
                    "buffer/size": len(replay_buffer),
                    "buffer/pushed": n_pushed,
                }
                # Per-task rollout SR
                for tid, sr in per_task_rollout_sr.items():
                    wandb_log[f"rollout/task_{tid:02d}_sr"] = sr
                if td_stats_list:
                    wandb_log.update({
                        "train/td_loss": log_entry["td_loss"],
                        "train/actor_loss": log_entry["actor_loss"],
                        "train/critic_loss": log_entry["critic_loss"],
                        "train/q1_mean": log_entry["q1_mean"],
                        "train/q2_mean": log_entry["q2_mean"],
                        "train/target_mean": log_entry["target_mean"],
                        "train/bc_penalty": log_entry["bc_penalty"],
                        "train/q_actor_mean": log_entry["q_actor_mean"],
                        "train/actor_lr": optimizer_actor.param_groups[0]["lr"],
                        "train/n_updates": n_updates,
                    })
                if eval_sr is not None:
                    wandb_log["eval/success_rate"] = eval_sr
                    wandb_log["eval/best_success_rate"] = best_eval_sr
                    # Per-task eval SR
                    for tid, sr in per_task_eval_sr.items():
                        wandb_log[f"eval/task_{tid:02d}_sr"] = sr
                    if eval_result and "per_state" in eval_result:
                        for sid, sr in eval_result["per_state"].items():
                            wandb_log[f"eval/state_{sid:02d}"] = sr
                for ep in sorted(all_episodes, key=lambda e: -e.success):
                    if ep.video_path and os.path.exists(ep.video_path):
                        status = "success" if ep.success else "fail"
                        wandb_log[f"video/{status}"] = wandb.Video(
                            ep.video_path, fps=10, format="mp4")
                        break
                wandb.log(wandb_log, step=iteration)
                logger.info(f"[iter {iteration}] wandb.log OK (step={iteration})")
        except Exception as _log_err:
            logger.error(f"[iter {iteration}] LOGGING BLOCK EXCEPTION: {_log_err}")
            import traceback
            traceback.print_exc()

        # ── Checkpoint ────────────────────────────────
        if iteration % args.save_interval == 0:
            save_rlt_checkpoint(enc_dec, actor, q_critic,
                                iteration, args.output_dir, phase="rl_offpolicy")

    # Stop rollout thread
    _stop_rollout.set()
    rollout_thread.join(timeout=10)
    logger.info("Rollout thread stopped")

    # Wait for any pending async eval to finish + drain final results
    last_eval = _eval_thread_holder[0]
    if last_eval is not None and last_eval.is_alive():
        logger.info("Waiting for final async eval to finish (max 600s)...")
        last_eval.join(timeout=600)
    while True:
        try:
            eval_data = _eval_results_queue.get_nowait()
        except queue.Empty:
            break
        from_iter = eval_data["iteration"]
        eval_sr_final = eval_data["eval_sr"]
        if eval_sr_final > best_eval_sr:
            best_eval_sr = eval_sr_final
        logger.info(f"[ASYNC RESULT @ shutdown] from iter {from_iter}: "
                     f"SR={eval_sr_final:.2%} (best_eval={best_eval_sr:.2%})")

    # Stop rollout infrastructure
    if args.use_steplock:
        for gpu_id, pool in rollout_env_pools.items():
            pool.close()
            logger.info(f"  Closed PersistentEnvPool on GPU {gpu_id}")
    else:
        for gpu_id, server in rollout_servers.items():
            server.stop()
            logger.info(f"  Stopped BatchInferenceServer on GPU {gpu_id}")

    # Final save
    save_rlt_checkpoint(enc_dec, actor, q_critic,
                        args.max_iter, args.output_dir, phase="rl_offpolicy")
    metrics_path = Path(args.output_dir) / "metrics.json"
    metrics_path.parent.mkdir(parents=True, exist_ok=True)
    with open(metrics_path, "w") as f:
        json.dump(metrics_history, f, indent=2)
    logger.info(f"Done. Metrics -> {metrics_path}")

    if args.use_wandb:
        wandb.finish()