PR #1330

open

[Non-record]: JEPA v2 — Why same-sequence next-k JEPA collapses in causal LMs

by luciobaiocchi
val_bpb
1.4617
Architecture
Transformer
Optimizer
Artifact Size
8.70 MB

Training Techniques

Architecture
U-Net skip connections
11-layer U-Net Transformer with encoder-decoder skip connections.
parameters: {"layers":11,"dim":512,"encoder_layers":5,"decoder_layers":6}
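A minimal sketch of the U-Net skip-connection pattern described above, assuming the common stack-based wiring (encoder activations are pushed onto a stack and added back to the matching decoder layers); with 5 encoder and 6 decoder layers, one decoder layer is left without a skip. Layer internals are abstracted away as callables.

```python
def unet_transformer_forward(x, encoder_layers, decoder_layers):
    # Encoder activations are stacked; each decoder layer adds back the
    # most recent unused encoder output (U-Net-style skip connection).
    skips = []
    for layer in encoder_layers:
        x = layer(x)
        skips.append(x)
    for layer in decoder_layers:
        if skips:  # with more decoders than encoders, extras get no skip
            x = x + skips.pop()
        x = layer(x)
    return x
```

The exact skip pairing in the PR (which decoder layer receives which encoder output) is not shown here; this illustrates only the general shape.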
GQA
Grouped query attention with fewer KV heads than query heads.
parameters: {"query_heads":8,"kv_heads":4}
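The head-sharing scheme in GQA can be sketched as follows: with 8 query heads and 4 KV heads, each KV head serves a contiguous group of 2 query heads. The helper name and list representation are illustrative.

```python
def share_kv_heads(kv, query_heads=8, kv_heads=4):
    # Grouped-query attention: each KV head is shared by a contiguous
    # group of query heads; query head q reads KV head q // group_size.
    # `kv` is a list with one entry per KV head (e.g. a K or V tensor).
    group_size = query_heads // kv_heads
    return [kv[q // group_size] for q in range(query_heads)]
```

In a real implementation this is typically a `repeat_interleave` over the head dimension (or handled natively by the attention kernel), halving KV-cache size here relative to full multi-head attention.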
BigramHash
Adds hashed bigram embeddings to token embeddings before the transformer.
parameters: {"bigram_vocab_size":2048,"dim":512}
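A sketch of the BigramHash idea: each (previous, current) token pair is hashed into a small auxiliary vocabulary of 2048 ids, which index a second embedding table whose vectors are added to the ordinary token embeddings. The hash constant and the padding id for the first position are assumptions, not the PR's exact choices.

```python
def bigram_hash_ids(tokens, bigram_vocab_size=2048):
    # Hash each (prev, cur) token pair into a small auxiliary vocab.
    # The resulting ids index a bigram embedding table (dim 512 here)
    # whose rows are added to the token embeddings before the transformer.
    ids = []
    prev = 0  # illustrative: pad the first position with id 0
    for tok in tokens:
        h = (prev * 1000003 + tok) % bigram_vocab_size  # mix constant is illustrative
        ids.append(h)
        prev = tok
    return ids
```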
LeakyReLU
Uses LeakyReLU(0.5)^2 in the MLP instead of ReLU^2.
parameters: {"negative_slope":0.5}
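Read literally, the activation is the square of LeakyReLU with slope 0.5, so negative inputs contribute a quarter of the squared magnitude instead of zero (as in plain ReLU^2). A scalar sketch:

```python
def leaky_relu_sq(x, negative_slope=0.5):
    # LeakyReLU(x)^2: negatives are scaled by the slope, then squared,
    # so f(-x) = 0.25 * x^2 while f(x) = x^2 for x >= 0.
    y = x if x >= 0 else negative_slope * x
    return y * y
```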
Weight Averaging
EMA
parameters: {"decay":0.9999}
EMA
parameters: {"decay":0.9}
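The EMA weight-averaging update is the standard exponential moving average per parameter; a minimal sketch (flat parameter lists stand in for tensors, and the two listed decays would correspond to two separate averages):

```python
def ema_update(avg, current, decay=0.9999):
    # w_avg <- decay * w_avg + (1 - decay) * w_current, elementwise.
    # Higher decay (0.9999) averages over a longer window than 0.9.
    return [decay * a + (1 - decay) * c for a, c in zip(avg, current)]
```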
Quantization
int6
bits: 6
scope: all
Compression
lzma
level: 9
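The quantization and compression steps above can be sketched together. This assumes symmetric per-tensor int6 quantization (codes in [-32, 31]) and stores one code per byte before LZMA; the PR's actual packing and scale layout may differ (a real artifact would likely bit-pack four int6 values into three bytes).

```python
import lzma

def quantize_int6(weights):
    # Symmetric int6 quantization: map max |w| to code 31, clamp to [-32, 31].
    scale = max(abs(w) for w in weights) / 31 or 1.0  # avoid zero scale
    codes = [max(-32, min(31, round(w / scale))) for w in weights]
    return codes, scale

def pack_and_compress(codes):
    # Naive packing (one int6 per byte, offset to unsigned) + LZMA preset 9.
    raw = bytes(c + 32 for c in codes)
    return lzma.compress(raw, preset=9)
```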
Regularization
logit softcap
parameters: {"value":30}
Sequence Length
sequence_length
train_length: 1024
eval_length: null

Novel Contributions

  • Diagnoses why single-step same-sequence JEPA collapses in causal language models.
  • Introduces a multi-step JEPA objective with prediction offsets [1, 2, 4, 8] and per-offset loss weights.
  • Identifies and fixes a gradient accumulation bug where JEPA targets were cached from the wrong micro-batch.
  • Adds BigramHash embeddings to inject explicit bigram statistics into the model.
  • Uses int6 quantization and LZMA compression to reduce artifact size.
  • Uses artifact EMA to smooth the final checkpoint.
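The multi-step objective above can be sketched as a weighted sum of per-offset prediction losses against stop-gradient target features. The weights, the MSE criterion, and the list-based feature representation are illustrative assumptions; the PR's exact weighting and distance are not shown here.

```python
def multistep_jepa_loss(pred, target, offsets=(1, 2, 4, 8),
                        weights=(0.4, 0.3, 0.2, 0.1)):
    # pred[k][t]: feature predicted k steps ahead from position t.
    # target[t]: target-encoder feature at position t (treated as
    # stop-gradient / EMA-encoder output, so it cannot collapse the
    # predictor and encoder onto a trivial constant solution).
    total = 0.0
    n = len(target)
    for w, k in zip(weights, offsets):
        se = [(p - target[t + k]) ** 2 for t, p in enumerate(pred[k][: n - k])]
        total += w * sum(se) / len(se)  # weighted mean squared error per offset
    return total
```

Spreading the objective over several offsets denies the model the degenerate shortcut available at a single next-step offset, which is the collapse mode the PR diagnoses.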