PR #1312
Status: open · Non-record
JEPA-LM Latent Predictive World Model (first JEPA submission)
by adi-suresh01
val_bpb
1.3299
Architecture
Transformer
Optimizer
—
Artifact Size
15.3 MB
Training Techniques
Architecture
MLP3x
Uses a 3x MLP expansion in the transformer blocks.
parameters: {"multiplier":3}
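A minimal sketch of a feed-forward block with the listed 3x expansion. The module and layer names are assumptions; only the multiplier comes from the card.

```python
import torch
import torch.nn as nn

class MLP3x(nn.Module):
    """Transformer feed-forward block with a 3x hidden expansion.
    Hypothetical sketch; the PR's actual module layout may differ."""
    def __init__(self, d_model: int, multiplier: int = 3):
        super().__init__()
        self.up = nn.Linear(d_model, multiplier * d_model)
        self.down = nn.Linear(multiplier * d_model, d_model)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)))

x = torch.randn(2, 16, 512)
y = MLP3x(512)(x)  # same shape in and out
```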
LeakyReLU
Uses a squared LeakyReLU activation (negative slope 0.5).
parameters: {"negative_slope":0.5}
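A sketch of the squared-LeakyReLU activation, assuming the straightforward reading: LeakyReLU followed by elementwise squaring. Note that squaring makes the negative branch positive; whether the PR preserves sign there is not settled by the card.

```python
import torch
import torch.nn.functional as F

def leaky_relu_squared(x, negative_slope=0.5):
    # LeakyReLU, then square elementwise. The negative branch
    # (-2 -> -1 -> 1) comes out positive under this reading.
    y = F.leaky_relu(x, negative_slope)
    return y * y

out = leaky_relu_squared(torch.tensor([2.0, -2.0]))  # -> [4.0, 1.0]
```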
GQA
Grouped query attention with fewer KV heads than query heads.
parameters: {"query_heads":8,"kv_heads":4}
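Grouped query attention with the listed 8 query / 4 KV heads can be sketched as below: each KV head is shared by a group of query heads, so KV tensors are repeated to match before standard attention. Head dim and batch shapes are illustrative.

```python
import torch

def gqa(q, k, v):
    """Grouped query attention sketch: q has more heads than k/v,
    so each KV head is repeated across its query-head group."""
    group = q.size(1) // k.size(1)          # 8 // 4 = 2 queries per KV head
    k = k.repeat_interleave(group, dim=1)   # (B, Hq, T, D)
    v = v.repeat_interleave(group, dim=1)
    att = (q @ k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    return att.softmax(dim=-1) @ v

q = torch.randn(1, 8, 10, 64)   # 8 query heads
k = torch.randn(1, 4, 10, 64)   # 4 KV heads
v = torch.randn(1, 4, 10, 64)
out = gqa(q, k, v)
```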
U-Net skip connections
Encoder-decoder style skip connections with learned skip weights.
parameters: null
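The learned skip weights could look like the following: a decoder layer adds the matching encoder activation scaled by a trainable scalar. The zero initialization (skip starts disabled) is an assumption, not stated on the card.

```python
import torch
import torch.nn as nn

class LearnedSkip(nn.Module):
    """U-Net-style skip with a learned scalar gate.
    Zero-init means the skip contributes nothing at step 0 (assumption)."""
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(1))

    def forward(self, dec_x, enc_x):
        return dec_x + self.w * enc_x

dec = torch.randn(2, 16, 512)
enc = torch.randn(2, 16, 512)
out = LearnedSkip()(dec, enc)  # equals dec at init, since w == 0
```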
weight tying
Tied input and output embeddings.
parameters: null
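Weight tying in PyTorch is a single parameter-sharing assignment; dimensions below are illustrative, not taken from the PR.

```python
import torch.nn as nn

vocab, d_model = 1000, 64            # illustrative sizes
embed = nn.Embedding(vocab, d_model)
lm_head = nn.Linear(d_model, vocab, bias=False)
lm_head.weight = embed.weight        # one matrix serves both roles
```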
RoPE
Uses rotary positional embeddings.
parameters: null
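A minimal rotary-embedding sketch: each even/odd pair of channels is rotated by a position- and frequency-dependent angle. The base of 10000 is the common convention, assumed rather than read from the PR.

```python
import torch

def rope(x, base=10000.0):
    """Apply rotary positional embeddings to x of shape (..., T, D), D even.
    Position 0 is left unchanged (angle 0)."""
    *_, T, D = x.shape
    inv_freq = base ** (-torch.arange(0, D, 2, dtype=torch.float32) / D)
    ang = torch.arange(T, dtype=torch.float32)[:, None] * inv_freq  # (T, D/2)
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

x = torch.randn(2, 10, 64)
r = rope(x)
```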
LatentPredictor
Training-only bottleneck predictor at the encoder-decoder boundary for latent prediction.
parameters: {"input_dim":512,"hidden_dim":128,"output_dim":512}
depth recurrence
Recursive multi-horizon latent rollout predicts future representations at t+1, t+2, and t+3.
parameters: {"horizons":3}
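The recursive rollout could be sketched as follows: the predictor's output at one horizon is fed back in as the input for the next, with a loss against the (detached) target latent at each step. The MSE objective and stop-gradient placement are assumptions; only the three horizons come from the card.

```python
import torch
import torch.nn.functional as F

def rollout_loss(predictor, z, targets):
    """Recursive multi-horizon latent rollout: predict t+1 from t,
    t+2 from the t+1 prediction, and so on. Sketch under assumptions."""
    loss, pred = 0.0, z
    for tgt in targets:              # target latents at t+1, t+2, t+3
        pred = predictor(pred)       # reuse previous prediction as input
        loss = loss + F.mse_loss(pred, tgt.detach())
    return loss / len(targets)

predictor = torch.nn.Linear(512, 512)  # stand-in for the real predictor
z = torch.randn(4, 512)
targets = [torch.randn(4, 512) for _ in range(3)]
loss = rollout_loss(predictor, z, targets)
```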
Regularization
weight decay
parameters: null
LR Schedule
warmdown
parameters: null
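A warmdown schedule is typically constant LR followed by a linear decay to zero over the final stretch of training; the fraction below is an illustrative assumption.

```python
def warmdown_lr(step, total_steps, base_lr=1e-3, warmdown_frac=0.2):
    """Constant LR, then linear decay to 0 over the last warmdown_frac
    of training. Fraction and base LR are assumptions."""
    start = int(total_steps * (1 - warmdown_frac))
    if step < start:
        return base_lr
    return base_lr * (total_steps - step) / (total_steps - start)
```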
Novel Contributions
- First JEPA implementation for Parameter Golf
- Multi-horizon latent prediction auxiliary loss at the encoder-decoder boundary
- Predictor stripped at export for zero inference artifact overhead
- Positive ablation showing a small BPB improvement over the same architecture without JEPA
- Recursive latent rollout world-model style training objective
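The export-time predictor stripping noted above amounts to filtering the predictor's parameters out of the saved state dict; the module names here are hypothetical.

```python
import torch.nn as nn

# Hypothetical model layout: the JEPA predictor exists only for training.
model = nn.ModuleDict({
    "encoder": nn.Linear(512, 512),
    "decoder": nn.Linear(512, 512),
    "predictor": nn.Linear(512, 512),  # dropped at export
})

# Keep everything except predictor weights, so the exported
# artifact carries zero inference overhead from JEPA.
export_sd = {k: v for k, v in model.state_dict().items()
             if not k.startswith("predictor")}
```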