Problem 3: The Stripped Transformer¶
Task Description: A transformer has reached us in pieces. Someone tore out its front and its head: the embedding table, the very first layer, and the output head that turns hidden states back into words are all gone. What is left is the long middle of the network. The strange part is that this wreckage still knows things, and our job is to coax the answers back out of it.
Concretely, the model is a stripped Qwen3.5-2B. Only layers 1 to 23 and the final normalization survived; embed_tokens, layers[0], and lm_head are missing.
The model is being asked multiple-choice questions. Each question has a context, a query, and 4 answer options (plus 2 "sink" options that are never correct and never asked about). For every question you are handed the hidden state right after layer 0 of the original full model, so you can pick the forward pass back up from layer 1 without ever needing the embedding table.
Your goal: for each of the 1000 test questions, predict the index of the correct option, an integer in {0, 1, 2, 3}. You are scored on accuracy, and guessing blindly gets you about 25%.
Restrictions:
- No other models, no outside data. Work only with what survived and what you are given.
- Training is allowed, but you only get 100 labeled questions, so the real enemy is overfitting. Whatever you train, keep it small.
What you're given¶
download_corrupted_qwen.sh pulls everything into this folder:
- model.pt and model_config/: the surviving weights and the config to rebuild the architecture.
- student_train.pkl: 100 questions with their answers, for training and for checking yourself.
- student_val.pkl: 1000 questions without answers. This is the test set.
- sample_submission.csv: the format your submission has to follow.
The numbers live in pickled lists of dictionaries. One question is one dict:
| field | what it is |
|---|---|
| idx | the question's id (matches the submission) |
| embeds_after_l0 | a (seq_len, 2048) array, the hidden state right after layer 0 of the full model |
| eol_positions | the position of each option's last token, one per option |
| correct_idx | the index of the correct option (only in student_train.pkl) |
Step 1: (R)ead the problem¶
The first step in R.I.C.E. is to read the task and name its shape. Back in Problem 1 we had a rule of thumb: text categorization is the home turf of encoders like BERT, while language modelling (predicting the next word) is what decoders like GPT are built for. This time we are handed a decoder, a Qwen language model, with a multiple-choice question to answer.
So what is the shape here? On paper it is plain multi-class classification: pick 1 of 4 options. But the interesting part is where the answer comes from. We are not training a model to understand the question from scratch. The model already understands it. The catch is that the piece that would normally say the answer, the output head, has been torn off. So this is really a read-the-answer-out-of-a-model problem.
- The goal: for each question, decide which of the 4 options is correct.
- The problem: a decoder normally answers by turning its final hidden state into a probability over the vocabulary, through the lm_head. That head is gone, so we cannot just read off what it would have said. What survived is the stack of middle layers, and those layers still push the correct option and the wrong ones into different-looking hidden states.
- Why student_train.pkl is useful: with 100 labeled questions we can train a tiny classifier to read "is this the correct option?" straight out of those hidden states, and we can score ourselves before we ever touch the test set.
The plan, in plain words¶
Here is the picture. A decoder reads left to right, one token at a time, and at every position it keeps a running vector that has soaked up everything to its left. So by the time it reaches the last token of an answer option, that one vector has absorbed the context, the query, and the option itself. That position is the model's little summary of "given all of this, how good is this option?"
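To make the "soaked up everything to its left" idea concrete, here is a minimal sketch of a causal mixer on made-up vectors. It is a toy single-head attention with identity projections (the function name `causal_attention` and all sizes are assumptions for illustration, not the Qwen block). It shows that changing later tokens leaves earlier positions untouched, so a position only ever summarizes what came before it.

```python
import numpy as np


def causal_attention(x):
    """Toy single-head causal self-attention with identity projections."""
    seq_len, dim = x.shape
    scores = x @ x.T / np.sqrt(dim)
    # Mask out the future: position t may only look at positions <= t
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # Row-wise softmax (masked entries become exactly 0)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ x


rng = np.random.default_rng(0)
tokens = rng.normal(size=(6, 8))
edited = tokens.copy()
edited[4:] = rng.normal(size=(2, 8))  # change only the last two tokens

out_a = causal_attention(tokens)
out_b = causal_attention(edited)

unchanged_prefix = np.allclose(out_a[:4], out_b[:4])
changed_suffix = not np.allclose(out_a[4:], out_b[4:])
print("first four positions unchanged:", unchanged_prefix)
print("last two positions changed:", changed_suffix)
```

The same logic is why the last token of an option is the natural place to read: it is the first position that has seen the whole option.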
Normally the model would push that vector through the output head to guess the next word. We do not have the head, so we do the reading ourselves: run the surviving layers, grab the vector sitting at each option's last token, and train a small classifier to tell the correct option's vector apart from the wrong ones. A small classifier trained on a frozen model's hidden states, just to see what they encode, is called a probe.
Two things let us get away with this even though the front and the head are missing:
- We are handed the hidden state after layer 0, so we never need the embedding table to get started. We drop that vector in and let layers 1 onward run.
- We read hidden states, not words, so we never need the output head either.
Let's actually look at the data.
import os
import pickle

import numpy as np

HERE = "."
TRAIN = os.path.join(HERE, "student_train.pkl")
VAL = os.path.join(HERE, "student_val.pkl")

# Load both splits; `with` closes each file once it has been read
with open(TRAIN, "rb") as f:
    train = pickle.load(f)
with open(VAL, "rb") as f:
    val = pickle.load(f)

print(
    f"Train: {len(train)} questions (labeled) Val: {len(val)} questions (unlabeled)"
)

# What does one question look like?
print("\nFields of one question:")
for key, value in train[0].items():
    shape = getattr(value, "shape", None)
    print(
        f"\t{key:<16} {type(value).__name__:<8} shape={shape} dtype={getattr(value, 'dtype', None)}"
    )
print(
    f"\nFirst question:"
    f"\n\tOption token positions: {train[0]['eol_positions']}"
    f"\n\tCorrect option index: {train[0]['correct_idx']}"
)
Train: 100 questions (labeled) Val: 1000 questions (unlabeled)

Fields of one question:
    idx              int      shape=None dtype=None
    embeds_after_l0  ndarray  shape=(227, 2048) dtype=float16
    eol_positions    list     shape=None dtype=None
    correct_idx      int      shape=None dtype=None

First question:
    Option token positions: [164, 183, 199, 208]
    Correct option index: 0
Rebuilding the stripped model¶
Same two-step move as Problem 2: read the config to build the empty architecture, then pour the saved weights in. The new wrinkle is the three holes. We plug each one with a no-op so the forward pass runs from end to end:
- embed_tokens turns token ids into vectors. We already have vectors, so we swap it for nn.Identity and never call it.
- layers[0] is the first block. Our input is the state after layer 0, so running layer 0 again would apply it twice. We swap it for a pass-through.
- lm_head turns the final hidden state into word scores. We want the hidden states themselves, so we swap it for nn.Identity too.
The one thing to watch is load_state_dict(..., strict=False). The checkpoint is missing those three pieces, and strict=False is what lets the rest of the weights load anyway instead of erroring on the gaps.
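To see what strict=False actually does, here is a minimal sketch on a toy module. The class `Tiny` and its layer names are made up for illustration; this is not the Qwen model. The point it shows is that `load_state_dict` does not fail on the gaps, but it does report them, so you can check that only the parts you expect are missing.

```python
import torch
from torch import nn


class Tiny(nn.Module):
    """Hypothetical three-part model: front, middle, head."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.body = nn.Linear(4, 4)
        self.head = nn.Linear(4, 10)


full = Tiny()
# Simulate a stripped checkpoint: keep only the "body" weights
stripped = {
    name: tensor
    for name, tensor in full.state_dict().items()
    if name.startswith("body.")
}

fresh = Tiny()
result = fresh.load_state_dict(stripped, strict=False)

# The surviving weights were loaded; the gaps are listed rather than raised
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
body_loaded = torch.equal(fresh.body.weight, full.body.weight)
print("body weights loaded:", body_loaded)
```

On the real checkpoint, printing the same return value is a cheap sanity check: any missing key outside the stripped parts would mean some surviving layer is still at its random initialization.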
import os

# Let unsupported ops fall back to CPU instead of erroring on Apple Silicon (MPS).
# Set this before importing torch, to be safe, so it is in place when torch initializes.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch
from torch import nn
from transformers import Qwen3_5TextConfig, Qwen3_5ForCausalLM

# Pick the best available device: CUDA, then Apple's MPS, then CPU
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print("Using device:", device)


class _Pass(nn.Module):
    """A transformer block replacement that returns its input unchanged."""

    def forward(self, hidden_states, *args, **kwargs):
        return hidden_states


# Build the architecture from the config, then load the surviving weights
config = Qwen3_5TextConfig.from_pretrained(os.path.join(HERE, "model_config"))
model = Qwen3_5ForCausalLM(config).to(dtype=torch.float16, device=device)
model.load_state_dict(
    torch.load(os.path.join(HERE, "model.pt"), weights_only=True),
    strict=False,  # the checkpoint is missing the three stripped parts
)

# Fill the three holes with no-ops so the forward pass runs
model.model.embed_tokens = nn.Identity()  # inputs are already embedded
model.model.layers[0] = _Pass()  # we start from the post-layer-0 state
model.lm_head = nn.Identity()  # we read hidden states, not word scores
model.eval()
Using device: mps
Qwen3_5ForCausalLM(
(model): Qwen3_5TextModel(
(embed_tokens): Identity()
(layers): ModuleList(
(0): _Pass()
(1-2): 2 x Qwen3_5DecoderLayer(
(linear_attn): Qwen3_5GatedDeltaNet(
(act): SiLUActivation()
(conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)
(norm): Qwen3_5RMSNormGated()
(out_proj): Linear(in_features=2048, out_features=2048, bias=False)
(in_proj_qkv): Linear(in_features=2048, out_features=6144, bias=False)
(in_proj_z): Linear(in_features=2048, out_features=2048, bias=False)
(in_proj_b): Linear(in_features=2048, out_features=16, bias=False)
(in_proj_a): Linear(in_features=2048, out_features=16, bias=False)
)
(mlp): Qwen3_5MLP(
(gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
(up_proj): Linear(in_features=2048, out_features=6144, bias=False)
(down_proj): Linear(in_features=6144, out_features=2048, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
)
(3): Qwen3_5DecoderLayer(
(self_attn): Qwen3_5Attention(
(q_proj): Linear(in_features=2048, out_features=4096, bias=False)
(k_proj): Linear(in_features=2048, out_features=512, bias=False)
(v_proj): Linear(in_features=2048, out_features=512, bias=False)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(q_norm): Qwen3_5RMSNorm((256,), eps=1e-06)
(k_norm): Qwen3_5RMSNorm((256,), eps=1e-06)
)
(mlp): Qwen3_5MLP(
(gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
(up_proj): Linear(in_features=2048, out_features=6144, bias=False)
(down_proj): Linear(in_features=6144, out_features=2048, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
)
(4-6): 3 x Qwen3_5DecoderLayer(
(linear_attn): Qwen3_5GatedDeltaNet(
(act): SiLUActivation()
(conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)
(norm): Qwen3_5RMSNormGated()
(out_proj): Linear(in_features=2048, out_features=2048, bias=False)
(in_proj_qkv): Linear(in_features=2048, out_features=6144, bias=False)
(in_proj_z): Linear(in_features=2048, out_features=2048, bias=False)
(in_proj_b): Linear(in_features=2048, out_features=16, bias=False)
(in_proj_a): Linear(in_features=2048, out_features=16, bias=False)
)
(mlp): Qwen3_5MLP(
(gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
(up_proj): Linear(in_features=2048, out_features=6144, bias=False)
(down_proj): Linear(in_features=6144, out_features=2048, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
)
(7): Qwen3_5DecoderLayer(
(self_attn): Qwen3_5Attention(
(q_proj): Linear(in_features=2048, out_features=4096, bias=False)
(k_proj): Linear(in_features=2048, out_features=512, bias=False)
(v_proj): Linear(in_features=2048, out_features=512, bias=False)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(q_norm): Qwen3_5RMSNorm((256,), eps=1e-06)
(k_norm): Qwen3_5RMSNorm((256,), eps=1e-06)
)
(mlp): Qwen3_5MLP(
(gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
(up_proj): Linear(in_features=2048, out_features=6144, bias=False)
(down_proj): Linear(in_features=6144, out_features=2048, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
)
(8-10): 3 x Qwen3_5DecoderLayer(
(linear_attn): Qwen3_5GatedDeltaNet(
(act): SiLUActivation()
(conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)
(norm): Qwen3_5RMSNormGated()
(out_proj): Linear(in_features=2048, out_features=2048, bias=False)
(in_proj_qkv): Linear(in_features=2048, out_features=6144, bias=False)
(in_proj_z): Linear(in_features=2048, out_features=2048, bias=False)
(in_proj_b): Linear(in_features=2048, out_features=16, bias=False)
(in_proj_a): Linear(in_features=2048, out_features=16, bias=False)
)
(mlp): Qwen3_5MLP(
(gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
(up_proj): Linear(in_features=2048, out_features=6144, bias=False)
(down_proj): Linear(in_features=6144, out_features=2048, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
)
(11): Qwen3_5DecoderLayer(
(self_attn): Qwen3_5Attention(
(q_proj): Linear(in_features=2048, out_features=4096, bias=False)
(k_proj): Linear(in_features=2048, out_features=512, bias=False)
(v_proj): Linear(in_features=2048, out_features=512, bias=False)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(q_norm): Qwen3_5RMSNorm((256,), eps=1e-06)
(k_norm): Qwen3_5RMSNorm((256,), eps=1e-06)
)
(mlp): Qwen3_5MLP(
(gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
(up_proj): Linear(in_features=2048, out_features=6144, bias=False)
(down_proj): Linear(in_features=6144, out_features=2048, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
)
(12-14): 3 x Qwen3_5DecoderLayer(
(linear_attn): Qwen3_5GatedDeltaNet(
(act): SiLUActivation()
(conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)
(norm): Qwen3_5RMSNormGated()
(out_proj): Linear(in_features=2048, out_features=2048, bias=False)
(in_proj_qkv): Linear(in_features=2048, out_features=6144, bias=False)
(in_proj_z): Linear(in_features=2048, out_features=2048, bias=False)
(in_proj_b): Linear(in_features=2048, out_features=16, bias=False)
(in_proj_a): Linear(in_features=2048, out_features=16, bias=False)
)
(mlp): Qwen3_5MLP(
(gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
(up_proj): Linear(in_features=2048, out_features=6144, bias=False)
(down_proj): Linear(in_features=6144, out_features=2048, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
)
(15): Qwen3_5DecoderLayer(
(self_attn): Qwen3_5Attention(
(q_proj): Linear(in_features=2048, out_features=4096, bias=False)
(k_proj): Linear(in_features=2048, out_features=512, bias=False)
(v_proj): Linear(in_features=2048, out_features=512, bias=False)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(q_norm): Qwen3_5RMSNorm((256,), eps=1e-06)
(k_norm): Qwen3_5RMSNorm((256,), eps=1e-06)
)
(mlp): Qwen3_5MLP(
(gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
(up_proj): Linear(in_features=2048, out_features=6144, bias=False)
(down_proj): Linear(in_features=6144, out_features=2048, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
)
(16-18): 3 x Qwen3_5DecoderLayer(
(linear_attn): Qwen3_5GatedDeltaNet(
(act): SiLUActivation()
(conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)
(norm): Qwen3_5RMSNormGated()
(out_proj): Linear(in_features=2048, out_features=2048, bias=False)
(in_proj_qkv): Linear(in_features=2048, out_features=6144, bias=False)
(in_proj_z): Linear(in_features=2048, out_features=2048, bias=False)
(in_proj_b): Linear(in_features=2048, out_features=16, bias=False)
(in_proj_a): Linear(in_features=2048, out_features=16, bias=False)
)
(mlp): Qwen3_5MLP(
(gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
(up_proj): Linear(in_features=2048, out_features=6144, bias=False)
(down_proj): Linear(in_features=6144, out_features=2048, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
)
(19): Qwen3_5DecoderLayer(
(self_attn): Qwen3_5Attention(
(q_proj): Linear(in_features=2048, out_features=4096, bias=False)
(k_proj): Linear(in_features=2048, out_features=512, bias=False)
(v_proj): Linear(in_features=2048, out_features=512, bias=False)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(q_norm): Qwen3_5RMSNorm((256,), eps=1e-06)
(k_norm): Qwen3_5RMSNorm((256,), eps=1e-06)
)
(mlp): Qwen3_5MLP(
(gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
(up_proj): Linear(in_features=2048, out_features=6144, bias=False)
(down_proj): Linear(in_features=6144, out_features=2048, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
)
(20-22): 3 x Qwen3_5DecoderLayer(
(linear_attn): Qwen3_5GatedDeltaNet(
(act): SiLUActivation()
(conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144, bias=False)
(norm): Qwen3_5RMSNormGated()
(out_proj): Linear(in_features=2048, out_features=2048, bias=False)
(in_proj_qkv): Linear(in_features=2048, out_features=6144, bias=False)
(in_proj_z): Linear(in_features=2048, out_features=2048, bias=False)
(in_proj_b): Linear(in_features=2048, out_features=16, bias=False)
(in_proj_a): Linear(in_features=2048, out_features=16, bias=False)
)
(mlp): Qwen3_5MLP(
(gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
(up_proj): Linear(in_features=2048, out_features=6144, bias=False)
(down_proj): Linear(in_features=6144, out_features=2048, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
)
(23): Qwen3_5DecoderLayer(
(self_attn): Qwen3_5Attention(
(q_proj): Linear(in_features=2048, out_features=4096, bias=False)
(k_proj): Linear(in_features=2048, out_features=512, bias=False)
(v_proj): Linear(in_features=2048, out_features=512, bias=False)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(q_norm): Qwen3_5RMSNorm((256,), eps=1e-06)
(k_norm): Qwen3_5RMSNorm((256,), eps=1e-06)
)
(mlp): Qwen3_5MLP(
(gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
(up_proj): Linear(in_features=2048, out_features=6144, bias=False)
(down_proj): Linear(in_features=6144, out_features=2048, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen3_5RMSNorm((2048,), eps=1e-06)
)
)
(norm): Qwen3_5RMSNorm((2048,), eps=1e-06)
(rotary_emb): Qwen3_5TextRotaryEmbedding()
)
(lm_head): Identity()
)
What one block looks like¶
We keep saying "the surviving layers," so let's open one up. If you print a single block you will notice this is not the textbook decoder. Most of the model's blocks use a cheap linear-attention mixer, and only every 4th block uses classic full attention. The config spells this out in layer_types. Both kinds share the same outer shape: mix information across tokens, then apply a gated feed-forward, with RMSNorm and residual adds around each. They just mix tokens differently:
- A linear-attention block carries a linear_attn mixer: the projections (in_proj_*), a short causal conv1d, and a couple of decay parameters (A_log, dt_bias). It summarizes the past with a running state instead of comparing every pair of tokens, which is what makes it cheap.
- A full-attention block carries the familiar self_attn: q_proj, k_proj, v_proj, o_proj, with far fewer key and value heads than query heads (grouped-query attention), plus a per-head q_norm and k_norm.
Both kinds are causal, information only flows left to right, which is why the last token of an option is the place to read.
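To picture what "a running state instead of comparing every pair of tokens" means, here is a minimal sketch of plain, unnormalized linear attention on random vectors. This is a deliberate simplification and an assumption on my part: it has no softmax, no gating, and none of the decay parameters the real linear_attn block carries. It only shows that the pairwise sum over the past and the one-matrix running summary give the same answer.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, dim = 5, 3
q = rng.normal(size=(seq_len, dim))
k = rng.normal(size=(seq_len, dim))
v = rng.normal(size=(seq_len, dim))

# Pairwise view: position t visits every earlier position j <= t
pairwise = np.zeros((seq_len, dim))
for t in range(seq_len):
    for j in range(t + 1):
        pairwise[t] += (q[t] @ k[j]) * v[j]

# Running-state view: keep one (dim, dim) matrix, update it once per token
state = np.zeros((dim, dim))
recurrent = np.zeros((seq_len, dim))
for t in range(seq_len):
    state += np.outer(k[t], v[t])  # fold token t into the summary of the past
    recurrent[t] = q[t] @ state    # read the summary with this token's query

same = np.allclose(pairwise, recurrent)
print("pairwise and running-state outputs match:", same)
```

The pairwise loop does work that grows with the square of the sequence length, while the running-state loop does a fixed amount of work per token. Both only ever look left, which matches the causal behaviour described above.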
# Peek at two real blocks: a linear-attention one and a full-attention one
for layer_idx in [1, 3]:
    print(f"\nLayer {layer_idx} ({config.layer_types[layer_idx]})")
    for name, param in model.model.layers[layer_idx].named_parameters():
        print(f"\t{name:32s} {tuple(param.shape)}")
Layer 1 (linear_attention)
    linear_attn.dt_bias              (16,)
    linear_attn.A_log                (16,)
    linear_attn.conv1d.weight        (6144, 1, 4)
    linear_attn.norm.weight          (128,)
    linear_attn.out_proj.weight      (2048, 2048)
    linear_attn.in_proj_qkv.weight   (6144, 2048)
    linear_attn.in_proj_z.weight     (2048, 2048)
    linear_attn.in_proj_b.weight     (16, 2048)
    linear_attn.in_proj_a.weight     (16, 2048)
    mlp.gate_proj.weight             (6144, 2048)
    mlp.up_proj.weight               (6144, 2048)
    mlp.down_proj.weight             (2048, 6144)
    input_layernorm.weight           (2048,)
    post_attention_layernorm.weight  (2048,)

Layer 3 (full_attention)
    self_attn.q_proj.weight          (4096, 2048)
    self_attn.k_proj.weight          (512, 2048)
    self_attn.v_proj.weight          (512, 2048)
    self_attn.o_proj.weight          (2048, 2048)
    self_attn.q_norm.weight          (256,)
    self_attn.k_norm.weight          (256,)
    mlp.gate_proj.weight             (6144, 2048)
    mlp.up_proj.weight               (6144, 2048)
    mlp.down_proj.weight             (2048, 6144)
    input_layernorm.weight           (2048,)
    post_attention_layernorm.weight  (2048,)
First, just pass one question through¶
Before we build the baseline properly, let's push a single question through the model and see what falls out. Here is roughly what the model was originally given (we don't get the text, only the numbers, but it helps to picture it):
Context: The Nile is the longest river in Africa, flowing north through eleven countries before reaching the sea.
Question: Which sea does the Nile empty into?
A. The Mediterranean Sea
B. The Red Sea
C. The Atlantic Ocean
D. Lake Victoria
E. <sink>
F. <sink>
We never actually see those words. All we have for this question is embeds_after_l0, the whole sequence already turned into vectors and pushed through layer 0. Good thing too, because that is exactly enough to keep going. We drop it into the surviving layers and read the hidden state at each option's last token (the positions in eol_positions). This gives us one vector per option.
# Take one question from train and push it through the surviving layers
sample = train[0]
device = next(model.parameters()).device

# embeds_after_l0 is the sequence already turned into vectors and run through layer 0
embeds = torch.from_numpy(sample["embeds_after_l0"]).unsqueeze(0).to(device).half()
mask = torch.ones(1, embeds.shape[1], dtype=torch.long, device=device)

with torch.no_grad():
    outputs = model.model(inputs_embeds=embeds, attention_mask=mask)

hidden = outputs.last_hidden_state[0]  # (seq_len, 2048): one vector per token
option_vectors = hidden[
    sample["eol_positions"]
]  # (n_opts, 2048): one vector per option

print(f"Input sequence: {tuple(embeds.shape[1:])} (tokens, hidden)")
print(f"After the layers: {tuple(hidden.shape)} (tokens, hidden)")
print(f"Option vectors: {tuple(option_vectors.shape)} (options, hidden)")
print(
    f"\nSo we get {option_vectors.shape[0]} vectors, one per option, "
    f"each {option_vectors.shape[1]} numbers long."
)
Input sequence: (227, 2048) (tokens, hidden)
After the layers: (227, 2048) (tokens, hidden)
Option vectors: (4, 2048) (options, hidden)

So we get 4 vectors, one per option, each 2048 numbers long.
sample
{'idx': 0,
'embeds_after_l0': array([[ 0.02344 , -0.03394 , -0.1338 , ..., -0.01709 , 0.01782 ,
-0.0581 ],
[-0.0337 , -0.02332 , 0.02222 , ..., -0.01929 , 0.001709 ,
0.003937 ],
[ 0.007935 , 0.03906 , 0.002686 , ..., 0.0004578, -0.0007324,
-0.01648 ],
...,
[-0.005005 , -0.00299 , -0.01422 , ..., 0.002228 , 0.01587 ,
0.00238 ],
[ 0.0221 , 0.01068 , 0.005005 , ..., -0.06152 , 0.04907 ,
-0.0752 ],
[-0.01257 , 0.006958 , 0.03613 , ..., -0.01782 , -0.0238 ,
0.0442 ]], shape=(227, 2048), dtype=float16),
'eol_positions': [164, 183, 199, 208],
'correct_idx': 0}
embeds
tensor([[[ 0.0234, -0.0339, -0.1338, ..., -0.0171, 0.0178, -0.0581],
[-0.0337, -0.0233, 0.0222, ..., -0.0193, 0.0017, 0.0039],
[ 0.0079, 0.0391, 0.0027, ..., 0.0005, -0.0007, -0.0165],
...,
[-0.0050, -0.0030, -0.0142, ..., 0.0022, 0.0159, 0.0024],
[ 0.0221, 0.0107, 0.0050, ..., -0.0615, 0.0491, -0.0752],
[-0.0126, 0.0070, 0.0361, ..., -0.0178, -0.0238, 0.0442]]],
device='mps:0', dtype=torch.float16)
hidden
tensor([[-1.3486e+00, -1.7305e+00, -6.1829e-02, ..., 4.5685e-02,
3.6353e-01, 6.1084e-01],
[-4.2389e-02, 3.7079e-02, -1.0138e-01, ..., -8.1592e-01,
7.8674e-02, 1.1211e+00],
[ 4.5430e+00, 4.3831e-03, 1.4197e-01, ..., 4.3828e+00,
-4.9103e-02, 1.2962e-02],
...,
[ 2.7852e+00, 2.2583e-02, -3.3472e-01, ..., 3.0298e-01,
4.1687e-02, 1.5784e-01],
[ 2.6641e+00, -3.8457e+00, -1.1438e+01, ..., -3.2031e-01,
-6.7444e-02, -9.3066e-01],
[ 7.1436e-01, 2.1667e-02, -8.1348e-01, ..., -1.2177e-01,
-1.6431e-01, -3.2051e+00]], device='mps:0', dtype=torch.float16)
Step 2: (I)mplement the Baseline¶
Before we change anything, we want a number to beat. The test set has no answers, so we cannot score there. The move is to hold out part of the 100 labeled questions and check ourselves on those.
The baseline is the most literal reading of the model: run the surviving layers, take the final hidden state at each option's last token, and train a logistic-regression probe to pick the correct option. The forward pass is the slow part, so here is a trick: run it once and cache the option vectors. After that the probe is instant to retrain, which matters because we are about to retrain it a lot.
from tqdm import tqdm


def collect(samples, layer=None, desc="Forward"):
    """Run each question through the surviving layers and grab the hidden state at
    every option's last token.

    layer=None  -> the final (post-norm) hidden state
    layer=<int> -> that intermediate layer's hidden state instead
    """
    collected = []
    for sample in tqdm(samples, desc=desc):
        embeds = (
            torch.from_numpy(sample["embeds_after_l0"]).unsqueeze(0).to(device).half()
        )
        mask = torch.ones(1, embeds.shape[1], dtype=torch.long, device=device)
        with torch.no_grad():
            outputs = model.model(
                inputs_embeds=embeds, attention_mask=mask, output_hidden_states=True
            )
        hidden = (
            outputs.last_hidden_state[0]
            if layer is None
            else outputs.hidden_states[layer][0]
        )
        collected.append(
            {
                "vectors": hidden[sample["eol_positions"]]
                .float()
                .cpu()
                .numpy(),  # (n_opts, hidden)
                "correct": sample.get("correct_idx"),
            }
        )
    return collected


base = collect(train, layer=None, desc="Train (last layer)")
Train (last layer): 100%|██████████| 100/100 [00:24<00:00, 4.06it/s]
What the classifier actually sees¶
Every option becomes one training row. The feature is that option's EOL vector (the 2048 numbers we read at its end-of-line position, taken from some layer N), and the label is whether that option was the correct one. For the baseline, layer N is the last layer; in Step 4 we change N to a middle layer. So a couple of questions, 4 options each, turn into a table like this (vectors shown truncated, and question 1's numbers are made up just to illustrate):
| Question | Option | EOL token position | EOL token vector at layer N (2048-dim) | Is correct? (label) |
|---|---|---|---|---|
| 0 | A | 164 | [-1.35, -0.04, 4.54, …] | 1 |
| 0 | B | 183 | [ 0.22, -0.91, 0.06, …] | 0 |
| 0 | C | 199 | [ 2.78, 0.02, -0.33, …] | 0 |
| 0 | D | 208 | [ 0.71, 0.02, -0.81, …] | 0 |
| 1 | A | 150 | [ 0.04, 1.12, -0.50, …] | 0 |
| 1 | B | 172 | [-0.88, 0.31, 0.19, …] | 0 |
| 1 | C | 190 | [ 1.07, -0.22, 0.44, …] | 1 |
| 1 | D | 205 | [ 0.33, -0.07, 0.92, …] | 0 |
So with 100 training questions you get 100 × 4 = 400 such rows: the features are the vector column, shape (400, 2048), and the labels are the last column, shape (400,), with exactly one 1 per question (100 ones, 300 zeros). The EOL position is only where we read the vector, and layer N is which layer we read it from; neither is fed to the classifier. That (400, 2048) matrix and (400,) label vector are the whole training set, which is what build_features assembles below.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
def build_features(rows):
    """Stack every option vector into a feature matrix, label=1 for the correct option."""
    features = np.concatenate([row["vectors"] for row in rows])
    labels = np.concatenate(
        [
            [option == row["correct"] for option in range(len(row["vectors"]))]
            for row in rows
        ]
    )
    return features, labels.astype(int)


def question_accuracy(classifier, scaler, rows):
    """For each question, pick the option the probe scores highest, compare to truth."""
    correct = 0
    for row in rows:
        probs = classifier.predict_proba(scaler.transform(row["vectors"]))[
            :, 1
        ]  # P(correct) per option
        correct += int(np.argmax(probs) == row["correct"])
    return correct / len(rows)


# Hold out 20 of the 100 labeled questions to check ourselves (the test set is unlabeled)
rng = np.random.default_rng(42)
order = rng.permutation(len(base))
val_rows = [base[position] for position in order[:20]]
train_rows = [base[position] for position in order[20:]]

train_features, train_labels = build_features(train_rows)
scaler = StandardScaler().fit(train_features)
classifier = LogisticRegression(max_iter=5000).fit(
    scaler.transform(train_features), train_labels
)

print("Random guess: 25.0%")
print(
    "Baseline (last layer) held-out acc:",
    f"{question_accuracy(classifier, scaler, val_rows) * 100:.1f}%",
)
Random guess: 25.0%
Baseline (last layer) held-out acc: 35.0%
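One reason the score is computed per question (argmax over the 4 options) rather than per row is that row-level accuracy is misleading when only 1 label in 4 is positive. The sketch below uses made-up labels, not the real data, and hypothetical variable names; it shows a probe that knows nothing still scoring 75% on rows while doing no better than chance on questions.

```python
import numpy as np

n_questions, n_options = 100, 4
rng = np.random.default_rng(0)
correct = rng.integers(0, n_options, size=n_questions)  # made-up answer key

# Row-level labels: exactly one 1 per question, three 0s
labels = np.zeros((n_questions, n_options), dtype=int)
labels[np.arange(n_questions), correct] = 1

# A useless probe that predicts "not correct" for every single row
always_zero = np.zeros_like(labels)
row_acc = (always_zero == labels).mean()

# The same useless probe gives every option the same score, so argmax
# falls back to option 0 and is right only when the answer happens to be 0
scores = np.zeros((n_questions, n_options))
per_question_acc = (scores.argmax(axis=1) == correct).mean()

print(f"row-level accuracy:      {row_acc:.2f}")
print(f"question-level accuracy: {per_question_acc:.2f}")
```

The row-level number is 0.75 by construction, while the question-level number is just the share of questions whose answer is option 0. Only the second is comparable to the 25% random-guess line.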
Step 3: (C)heck for Errors¶
Same habit as Problem 1: spot a symptom, then run a quick probe to confirm the cause. The common ones for this pipeline:
| Sign something is wrong | Likely cause | How to probe and check |
|---|---|---|
| Held-out accuracy sits near 25% | the probe is reading noise, or you are reading a layer with no signal | Confirm it beats always-guess-0; then try a middle layer instead of the last (see Step 4). |
| The number jumps around every time you run it | only 20 held-out questions, so one split is noisy | Average over many random splits, not one, before trusting any change. |
| Held-out accuracy is far below train accuracy | the probe memorized 2048 features from only 400 vectors | Lower C (stronger L2) and watch the train-vs-held-out gap shrink. |
| Held-out accuracy looks suspiciously high | leakage: options of one question landed on both sides of the split, or the scaler was fit on the held-out set | Split by whole question, and fit StandardScaler on the training questions only. |
| argmax sometimes returns an option that should not exist | you are scoring over the wrong set of vectors | Make sure you only stack and argmax over the 4 real option positions. |
| The last layer underperforms an earlier one | the final state is shaped for the output head we deleted | Sweep the layers and pick by cross-validation, do not assume the last is best. |
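The "one split is noisy" row is worth putting a number on. A rough way to do that, assuming each held-out question is an independent right-or-wrong trial, is the binomial standard error of an accuracy estimate. The helper name below is hypothetical, and the normal approximation is crude at only 20 questions, so treat this as a back-of-the-envelope sketch.

```python
import math


def accuracy_standard_error(accuracy, n_questions):
    """Binomial standard error of an accuracy measured on n_questions."""
    return math.sqrt(accuracy * (1.0 - accuracy) / n_questions)


observed = 0.35  # the baseline's held-out accuracy reported above
n_heldout = 20   # size of the held-out split

se = accuracy_standard_error(observed, n_heldout)
low = observed - 1.96 * se
high = observed + 1.96 * se

print(f"standard error: {se:.3f}")
print(f"rough 95% interval: {low:.2f} to {high:.2f}")
```

With 20 questions the standard error is about 0.11, so the rough interval around 35% stretches from roughly 14% to 56% and still contains the 25% chance level. That is why a single split cannot tell a real improvement from luck, and why averaging over many splits is the safer habit.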
Step 4: (E)nhance the Solution¶
There are many ways to push past the baseline, so I'll go in order of payoff: first the one change that matters most, then the cheap polish, then a few further ideas. Spend your time on the first one.
The big win: probe a middle layer, not the last¶
The single biggest lever in this whole problem is which layer you read. The last layer is tuned to feed the output head we threw away, so its vectors are arranged for guessing the next token, not for cleanly splitting a right answer from a wrong one. The middle of the network tends to hold more abstract, more linearly separable features, the kind a simple probe can read.
This is a standard probing move: instead of reading only the final state, ask for every layer's hidden states in one forward pass (output_hidden_states=True), probe each layer on its own, and keep the one that scores best by cross-validation over whole questions. You will usually find a middle layer beats the last by a wide margin. Run the sweep below to see which layer wins for this model.
def collect_all_layers(samples, desc="Forward"):
    """Like collect(), but keep every layer's option vectors so we can probe each one."""
    collected = []
    for sample in tqdm(samples, desc=desc):
        embeds = (
            torch.from_numpy(sample["embeds_after_l0"]).unsqueeze(0).to(device).half()
        )
        mask = torch.ones(1, embeds.shape[1], dtype=torch.long, device=device)
        with torch.no_grad():
            outputs = model.model(
                inputs_embeds=embeds, attention_mask=mask, output_hidden_states=True
            )
        collected.append(
            {
                "layer_vectors": [
                    layer_hidden[0, sample["eol_positions"]].float().cpu().numpy()
                    for layer_hidden in outputs.hidden_states
                ],
                "correct": sample.get("correct_idx"),
            }
        )
    return collected

per_layer = collect_all_layers(train, "All layers")
All layers: 100%|██████████| 100/100 [00:17<00:00, 5.58it/s]
def cv_accuracy(rows, layers, reg_c=0.01, n_splits=50, val_size=20, seed=42):
    """Repeated question-level holdout. `layers` is a list of layer indices to concatenate."""
    rng = np.random.default_rng(seed)
    accuracies = []
    for _ in range(n_splits):
        order = rng.permutation(len(rows))
        val_ids, train_ids = order[:val_size], order[val_size:]

        def make_split(ids):
            features = np.concatenate(
                [
                    np.concatenate(
                        [rows[position]["layer_vectors"][layer] for layer in layers],
                        axis=1,
                    )
                    for position in ids
                ]
            )
            labels = np.concatenate(
                [
                    [option == rows[position]["correct"] for option in range(4)]
                    for position in ids
                ]
            )
            return features, labels.astype(int)

        train_features, train_labels = make_split(train_ids)
        scaler = StandardScaler().fit(train_features)
        classifier = LogisticRegression(max_iter=5000, C=reg_c).fit(
            scaler.transform(train_features), train_labels
        )
        correct = 0
        for position in val_ids:
            question_vectors = np.concatenate(
                [rows[position]["layer_vectors"][layer] for layer in layers], axis=1
            )
            probs = classifier.predict_proba(scaler.transform(question_vectors))[:, 1]
            correct += int(np.argmax(probs) == rows[position]["correct"])
        accuracies.append(correct / len(val_ids))
    return float(np.mean(accuracies)), float(np.std(accuracies))

# Sweep single middle layers, then try concatenating two good ones
for layer in range(10, 19):
    mean_acc, std_acc = cv_accuracy(per_layer, [layer])
    print(
        f"Layer {layer:2d} mean={mean_acc * 100:5.1f}% std={std_acc * 100:4.1f}%"
    )
mean_acc, std_acc = cv_accuracy(per_layer, [14, 17])
print(f"Layers 14+17 mean={mean_acc * 100:5.1f}% std={std_acc * 100:4.1f}%")
Layer 10 mean= 60.8% std= 9.4%
Layer 11 mean= 67.9% std= 8.7%
Layer 12 mean= 70.5% std= 8.1%
Layer 13 mean= 73.6% std= 7.7%
Layer 14 mean= 75.1% std= 7.8%
Layer 15 mean= 74.2% std= 8.1%
Layer 16 mean= 73.8% std= 8.6%
Layer 17 mean= 74.0% std= 8.2%
Layer 18 mean= 73.5% std= 8.1%
Layers 14+17 mean= 75.0% std= 8.5%
Cheap polish¶
- Regularize and validate. With 100 questions a logistic regression at the default C=1 will overfit. A much stronger setting like C=0.01 (heavy L2) works far better here. And pick the layer and C with many small holdouts, not one, since a single split of 20 questions is too noisy to trust.
- Combine a couple of layers. Concatenating two good layers (say 14 and 17) gives the probe two views of each option and can raise the score, though in the sweep above it made no measurable difference (75.0% for 14+17 against 75.1% for layer 14 alone). More than two rarely helps and just hands the probe more to overfit.
- Center each question. The probe only ever ranks the 4 options of one question against each other, never across questions. Subtracting each question's mean option vector before scaling removes what is common to all 4 and leaves the relative signal.
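The per-question centering from the last bullet fits in a few lines. This is a minimal sketch on a made-up question whose 4 option vectors share a large common offset. The helper name `center_per_question` is hypothetical and does not appear in the lecture code.

```python
import numpy as np


def center_per_question(option_vectors):
    """Subtract the mean option vector of one question.

    option_vectors: (n_options, dim) array holding the options of a single question.
    """
    return option_vectors - option_vectors.mean(axis=0, keepdims=True)


# Toy question: every option carries the same offset plus a small relative part
shared = np.full((1, 3), 10.0)
relative = np.array(
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [-1.0, -1.0, -1.0],
    ]
)
question = shared + relative

# The shared offset is removed and only the relative part is left
centered = center_per_question(question)
```

If you try this, apply the same centering to the training questions and to the test questions, before the scaler in both cases, so the probe sees the same kind of input at fit time and at prediction time.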
Going further¶
- A non-linear probe. A small MLP can beat the linear probe, but with 400 vectors it overfits fast, so keep it small and regularized.
- Ensembling. Averaging the probabilities of several probes (trained on different layers or seeds) usually gives a small, consistent gain.
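For the ensembling idea, one possible shape is sketched below. It assumes each probe has already produced a probability of being correct for each of the 4 options of one question. The three probability arrays are invented for illustration, and `ensemble_predict` is a hypothetical helper, not part of the lecture code.

```python
import numpy as np


def ensemble_predict(prob_list):
    """Average per-option probabilities from several probes and pick the best option.

    prob_list: list of (n_options,) arrays, one per probe.
    """
    mean_probs = np.mean(np.stack(prob_list), axis=0)
    return int(np.argmax(mean_probs)), mean_probs


# Made-up outputs of three probes for one question
probe_a = np.array([0.10, 0.40, 0.35, 0.15])
probe_b = np.array([0.05, 0.30, 0.45, 0.20])  # this probe alone would pick option 2
probe_c = np.array([0.10, 0.45, 0.30, 0.15])

choice, mean_probs = ensemble_predict([probe_a, probe_b, probe_c])
print(choice)  # option 1 wins after averaging
```

Averaging probabilities rather than taking a majority vote keeps each probe's confidence in play, so one unsure probe does not outvote two confident ones.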
import csv

# Pick the layer(s) the sweep liked best, then train on ALL 100 labeled questions
best_layers = [14]
train_features = np.concatenate(
    [
        np.concatenate([row["layer_vectors"][layer] for layer in best_layers], axis=1)
        for row in per_layer
    ]
)
train_labels = np.concatenate(
    [[option == row["correct"] for option in range(4)] for row in per_layer]
).astype(int)
scaler = StandardScaler().fit(train_features)
classifier = LogisticRegression(max_iter=5000, C=0.01).fit(
    scaler.transform(train_features), train_labels
)

# Run the forward pass on the test set and predict the highest-scoring option
val_layers = collect_all_layers(val, "Val")
with open("submission.csv", "w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["id", "target"])
    for sample, row in zip(val, val_layers):
        question_vectors = np.concatenate(
            [row["layer_vectors"][layer] for layer in best_layers], axis=1
        )
        probs = classifier.predict_proba(scaler.transform(question_vectors))[:, 1]
        writer.writerow([sample["idx"], int(np.argmax(probs))])
print("Wrote submission.csv")
Val: 100%|██████████| 1000/1000 [03:11<00:00, 5.23it/s]
Wrote submission.csv
IOAI Philippines 2026 NLP Lecture