PR #896 (open)

[Non-Record] JEPA Self-Distillation with EMA Target Encoder for Autoregressive LM (val_bpb: 1.19) | Current Noisy/Negative Result

by MVPandey
val_bpb: 1.1896
Architecture: Transformer
Optimizer: Muon
Artifact Size: 16.3MB

Training Techniques

Weight Averaging
  • EMA; parameters: {"decay":0.9995}

Architecture
  • GQA: grouped query attention with fewer KV heads than attention heads; parameters: {"heads":8,"kv_heads":4}
  • Partial RoPE: partial rotary positional embeddings
  • XSA: used in the last 4 layers; parameters: {"layers":4}
  • U-Net skip connections: skip connections in a U-Net style
  • BigramHash: bigram hash embedding component
  • LeakyReLU: LeakyReLU activation, with a squared variant in the backbone; parameters: {"slope":0.5}
  • MLP3x: three-times-expanded MLP; parameters: {"multiplier":3}
  • QK RMSNorm: RMSNorm applied to the Q and K projections
  • predictor MLP: training-only predictor network for JEPA latent prediction; parameters: {"layers":2,"dimensions":256}
  • projection heads: context and target projection heads for JEPA; parameters: {"input_dim":512,"output_dim":256}

Regularization
  • logit softcap; parameters: {"cap":30}
  • VICReg
  • weight decay

LR Schedule
  • linear warmup; parameters: {"warmup_steps":200}
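The EMA target encoder listed above amounts to a per-parameter moving average of the online weights. A minimal sketch, using the listed decay of 0.9995; the function name and flat-list representation are illustrative, not the PR's code:

```python
def ema_update(target_params, online_params, decay=0.9995):
    # target <- decay * target + (1 - decay) * online, applied per parameter;
    # decay=0.9995 matches the parameters listed above.
    return [decay * t + (1.0 - decay) * o
            for t, o in zip(target_params, online_params)]

# After one step the target moves 0.05% of the way toward the online weights.
target = ema_update([1.0, 0.0], [0.0, 1.0])
```

With a decay this close to 1, the target encoder changes slowly, which is what makes it a stable distillation target.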
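The KV-head sharing implied by the GQA parameters {"heads":8,"kv_heads":4} can be sketched as a simple repeat along the head axis; the tensor layout (kv_heads, seq, head_dim) and function name are assumptions for this sketch:

```python
import numpy as np

def gqa_expand_kv(kv, num_heads=8, num_kv_heads=4):
    # Each KV head serves num_heads // num_kv_heads query heads (2 here),
    # matching the {"heads":8,"kv_heads":4} parameters listed above.
    group = num_heads // num_kv_heads
    return np.repeat(kv, group, axis=0)  # (kv_heads, T, d) -> (heads, T, d)

k = np.arange(4.0).reshape(4, 1, 1) * np.ones((1, 2, 3))  # 4 distinct KV heads
k_full = gqa_expand_kv(k)  # query heads 0,1 share KV head 0; 2,3 share head 1; ...
```

Halving the KV heads halves the KV-cache size at inference while leaving the query-side head count unchanged.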
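The predictor MLP and projection heads combine into the JEPA latent-prediction objective: the online context latent is projected and passed through the predictor, then regressed onto the EMA target encoder's projected latent. The sketch below uses the listed dimensions (512-d hidden states, 256-d latents, 2-layer predictor); the random weights, plain-NumPy forward pass, and MSE criterion are placeholders for the PR's actual modules:

```python
import numpy as np

rng = np.random.default_rng(0)

# Dimensions follow the listed parameters; weights are random placeholders.
W_ctx = rng.normal(size=(512, 256)) / np.sqrt(512)  # context projection head
W_tgt = rng.normal(size=(512, 256)) / np.sqrt(512)  # target projection head
W_p1 = rng.normal(size=(256, 256)) / np.sqrt(256)   # predictor MLP, layer 1
W_p2 = rng.normal(size=(256, 256)) / np.sqrt(256)   # predictor MLP, layer 2

def jepa_latent_loss(h_ctx, h_tgt):
    # Predict the EMA target encoder's projected latent from the online
    # context latent; in training, gradients flow only through the context
    # branch (the target branch is stop-gradient).
    z_pred = np.maximum((h_ctx @ W_ctx) @ W_p1, 0.0) @ W_p2  # 2-layer predictor
    z_tgt = h_tgt @ W_tgt
    return np.mean((z_pred - z_tgt) ** 2)
```

Because predictor and projection heads sit only on the loss path, they can be dropped at inference, consistent with the zero-inference-cost claim below.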
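The logit softcap with cap 30 presumably smoothly bounds logits before the cross-entropy loss. A sketch assuming the common tanh formulation (popularized by Gemma 2); the exact form used in the PR is not shown:

```python
import math

def softcap(logit, cap=30.0):
    # Tanh soft cap: near-identity for small logits, smoothly saturating
    # toward +/-cap for large ones; cap=30 matches the listed parameters.
    return cap * math.tanh(logit / cap)
```

Unlike hard clipping, this keeps gradients nonzero everywhere while preventing logit blow-up.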

Novel Contributions

  • JEPA self-distillation for autoregressive language modeling using an EMA target encoder
  • Controlled A/B comparison against a vanilla cross-entropy baseline under matched seed, hardware, and wall-clock budget
  • Identification that JEPA latent prediction provides little benefit over next-token cross-entropy for BPE token prediction
  • Training-only JEPA auxiliary components with zero inference cost in the saved artifact
  • Empirical analysis of VICReg placement, EMA decay, JEPA loss weighting, and warmup stability
  • Discovery and correction of a quantization bug affecting artifact size reporting
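The "zero inference cost" point can be illustrated by filtering the JEPA auxiliary weights out of the checkpoint before saving; the key prefixes below are hypothetical, made up for this sketch rather than taken from the PR:

```python
# Training-only JEPA components (predictor MLP, projection heads) are dropped
# from the state dict before saving, so the artifact stays at inference size.
def strip_jepa_aux(state_dict,
                   aux_prefixes=("predictor.", "proj_ctx.", "proj_tgt.")):
    return {name: tensor for name, tensor in state_dict.items()
            if not name.startswith(aux_prefixes)}

ckpt = {"blocks.0.attn.w": 1, "predictor.fc1.w": 2, "proj_ctx.w": 3}
saved = strip_jepa_aux(ckpt)  # only backbone weights remain
```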