PR #1312

open

Non-record: JEPA-LM Latent Predictive World Model (first JEPA submission)

by adi-suresh01
val_bpb: 1.3299
Architecture: Transformer
Optimizer:
Artifact Size: 15.3 MB

Training Techniques

Architecture
MLP3x
Uses a 3x MLP expansion in the transformer blocks.
parameters: {"multiplier":3}
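A minimal NumPy sketch of the 3x expansion, assuming the usual two-projection transformer MLP (d_model -> 3*d_model -> d_model); shapes and the plain-ReLU nonlinearity here are illustrative, not the PR's exact block:

```python
import numpy as np

def mlp_block(x, w_in, w_out):
    """Transformer MLP with a 3x hidden expansion instead of the common 4x."""
    h = np.maximum(x @ w_in, 0.0)  # placeholder ReLU; the PR uses squared LeakyReLU
    return h @ w_out

d_model, multiplier = 512, 3                      # parameters: {"multiplier": 3}
rng = np.random.default_rng(0)
w_in = rng.normal(size=(d_model, multiplier * d_model)) * 0.02
w_out = rng.normal(size=(multiplier * d_model, d_model)) * 0.02

x = rng.normal(size=(4, d_model))                 # (tokens, d_model)
y = mlp_block(x, w_in, w_out)                     # y.shape == (4, 512)
```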
LeakyReLU
Uses a squared LeakyReLU activation with negative slope 0.5.
parameters: {"negative_slope":0.5}
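One plausible reading of "LeakyReLU(0.5) squared" is an elementwise square applied to the LeakyReLU output (analogous to the ReLU-squared activation used in speedrun models); a sketch under that assumption:

```python
import numpy as np

def squared_leaky_relu(x, negative_slope=0.5):
    """LeakyReLU followed by an elementwise square.
    Note the output is nonnegative: the negative branch is squared too."""
    y = np.where(x >= 0, x, negative_slope * x)
    return y * y
```

For example, an input of -2.0 maps to (-1.0)**2 = 1.0, while 2.0 maps to 4.0.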
GQA
Grouped query attention with fewer KV heads than query heads.
parameters: {"query_heads":8,"kv_heads":4}
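A minimal NumPy sketch of grouped query attention with the listed head counts: the 4 KV heads are each shared by 8/4 = 2 query heads, so the K/V projections are half the width of Q. Causal masking is omitted for brevity, and the weight shapes are illustrative:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def gqa(x, wq, wk, wv, q_heads=8, kv_heads=4):
    """Grouped query attention: each KV head serves q_heads // kv_heads query heads."""
    T, d = x.shape
    hd = d // q_heads                       # per-head dim
    group = q_heads // kv_heads             # query heads per KV head
    q = (x @ wq).reshape(T, q_heads, hd)
    k = (x @ wk).reshape(T, kv_heads, hd)
    v = (x @ wv).reshape(T, kv_heads, hd)
    k = np.repeat(k, group, axis=1)         # broadcast each KV head to its group
    v = np.repeat(v, group, axis=1)
    att = softmax(np.einsum('qhd,khd->hqk', q, k) / np.sqrt(hd))
    return np.einsum('hqk,khd->qhd', att, v).reshape(T, d)

d, T = 512, 6
rng = np.random.default_rng(0)
kv_dim = (d // 8) * 4                       # KV projection is half the width of Q
wq = rng.normal(size=(d, d)) * 0.02
wk = rng.normal(size=(d, kv_dim)) * 0.02
wv = rng.normal(size=(d, kv_dim)) * 0.02
out = gqa(rng.normal(size=(T, d)), wq, wk, wv)   # out.shape == (6, 512)
```

The KV cache shrinks by the same 2x factor, which is the usual motivation for GQA.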
U-Net skip connections
Encoder-decoder style skip connections with learned skip weights.
parameters: null
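The PR lists no parameters here, but a common formulation of learned skip weights is a scalar gate per encoder-decoder pair; a toy sketch under that assumption (the layer functions and weights below are purely illustrative):

```python
def unet_forward(x, enc_layers, dec_layers, skip_weights):
    """Save each encoder output and re-inject it into the mirrored decoder
    layer, scaled by a learned scalar skip weight (hypothetical scalar-gate
    formulation; the PR does not specify one)."""
    skips = []
    for f in enc_layers:
        x = f(x)
        skips.append(x)
    for g, lam in zip(dec_layers, skip_weights):
        x = g(x + lam * skips.pop())  # deepest skip pairs with earliest decoder layer
    return x

enc = [lambda x: x + 1, lambda x: x * 2]
dec = [lambda x: x - 1, lambda x: x / 2]
out = unet_forward(1.0, enc, dec, skip_weights=[0.5, 0.5])  # out == 3.0
```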
weight tying
Tied input and output embeddings.
parameters: null
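Weight tying means a single matrix serves as both the input embedding (row lookup) and the output head (transposed projection), which is also how it contributes to the small artifact size; a minimal sketch with illustrative dimensions:

```python
import numpy as np

vocab, d_model = 1000, 512
rng = np.random.default_rng(0)
W_emb = rng.normal(size=(vocab, d_model)) * 0.02  # the one shared matrix

def embed(token_ids):
    return W_emb[token_ids]       # input embedding: row lookup

def logits(h):
    return h @ W_emb.T            # output head: same matrix, transposed
```

With random Gaussian rows, a token's own logit dominates, since it is the squared norm of its embedding row.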
RoPE
Uses rotary positional embeddings.
parameters: null
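A standard RoPE sketch in NumPy: channel pairs are rotated by position-dependent angles, so position 0 is left unchanged and every rotation preserves the vector norm (the base frequency is the conventional 10000):

```python
import numpy as np

def rope(x, base=10000.0):
    """Rotary positional embedding: rotate (x1, x2) channel pairs by
    angles that grow with position and shrink with channel index."""
    T, d = x.shape
    half = d // 2
    freqs = base ** (-np.arange(half) / half)   # (d/2,) per-pair frequencies
    ang = np.outer(np.arange(T), freqs)         # (T, d/2) rotation angles
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[:, :half], x[:, half:]
    return np.concatenate([x1 * cos - x2 * sin,
                           x1 * sin + x2 * cos], axis=-1)

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 64))
x_rot = rope(x)
```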
LatentPredictor
Training-only bottleneck predictor at the encoder-decoder boundary for latent prediction.
parameters: {"input_dim":512,"hidden_dim":128,"output_dim":512}
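Using the listed dimensions, the predictor is a 512 -> 128 -> 512 bottleneck MLP; the hidden nonlinearity below is an assumed ReLU (the PR does not specify it). Because it exists only during training, dropping it at export costs nothing at inference:

```python
import numpy as np

class LatentPredictor:
    """Training-only bottleneck MLP at the encoder-decoder boundary:
    512 -> 128 -> 512, per parameters {"input_dim":512,"hidden_dim":128,"output_dim":512}."""
    def __init__(self, input_dim=512, hidden_dim=128, output_dim=512, seed=0):
        rng = np.random.default_rng(seed)
        self.w1 = rng.normal(size=(input_dim, hidden_dim)) * 0.02
        self.w2 = rng.normal(size=(hidden_dim, output_dim)) * 0.02

    def __call__(self, z):
        return np.maximum(z @ self.w1, 0.0) @ self.w2  # assumed ReLU hidden layer

pred = LatentPredictor()
z = np.random.default_rng(1).normal(size=(4, 512))
z_hat = pred(z)                     # z_hat.shape == (4, 512)
```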
depth recurrence
Recursive multi-horizon latent rollout predicts future representations at t+1, t+2, and t+3.
parameters: {"horizons":3}
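The recursion can be sketched as feeding each prediction back into the predictor to reach t+1, t+2, t+3, scoring every horizon against its target latent; the MSE objective and the toy predictor below are assumptions, since the PR only specifies horizons=3:

```python
import numpy as np

def rollout_loss(predictor, z_t, targets):
    """Recursive multi-horizon rollout: z_{t+k} = predictor(z_{t+k-1}),
    averaged MSE over horizons (hypothetical loss form; horizons=3 per the PR)."""
    loss, z = 0.0, z_t
    for z_target in targets:      # target latents for t+1, t+2, t+3
        z = predictor(z)          # recurse on the previous prediction
        loss += np.mean((z - z_target) ** 2)
    return loss / len(targets)

halve = lambda z: 0.5 * z         # toy predictor for illustration
z_t = np.array([2.0])
targets = [np.array([1.0]), np.array([0.5]), np.array([0.25])]
loss = rollout_loss(halve, z_t, targets)   # exact rollout -> loss == 0.0
```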
Regularization
weight decay
parameters: null
LR Schedule
warmdown
parameters: null
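No parameters are given for the warmdown; one common reading (as in trapezoidal speedrun schedules) is a constant learning rate followed by a linear decay to zero over the final fraction of training. A sketch under that assumption, with an illustrative base LR and warmdown fraction:

```python
def lr_at(step, total_steps, base_lr=1e-3, warmdown_frac=0.2):
    """Constant LR, then a linear 'warmdown' to zero over the last
    warmdown_frac of training (assumed shape; the PR gives no parameters)."""
    start = int(total_steps * (1 - warmdown_frac))
    if step < start:
        return base_lr
    return base_lr * (total_steps - step) / (total_steps - start)

# e.g. over 1000 steps: full LR until step 800, halved at step 900, zero at 1000
```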

Novel Contributions

  • First JEPA implementation for Parameter Golf
  • Multi-horizon latent prediction auxiliary loss at the encoder-decoder boundary
  • Predictor stripped at export for zero inference artifact overhead
  • Positive ablation showing a small BPB improvement over the same architecture without JEPA
  • Recursive latent rollout world-model style training objective