Problem 2: Corrupted Transformers
Task Description: You are given pretrained transformer models that have been damaged. Your job is to restore them so they perform as well as possible on their task. The models are otherwise standard; only specific parts have been corrupted.
The problem comes in two parts, each showing a different kind of corruption.
Subproblem 1 (Transformer A): Some rows of the token embedding table have been set to zero.
Subproblem 2 (Transformers B and C): You are given two damaged checkpoints of the same model.
- Transformer B: some weights have been set to zero, but the layers are in the correct order.
- Transformer C: fewer weights have been zeroed, but the layers are in the wrong order.
Restrictions:
- You may not use any additional data.
- You may not use any other pretrained model.
- Subproblem 1 only: you may not retrain or fine-tune the model. You must repair the embedding table directly. (Subproblem 2 allows training.)
What you're given
Subproblem 1: Transformer A
- the corrupted model
- val_dataset.csv (labeled, for checking your work)
- test.csv (unlabeled, for your submission)
Subproblem 2: Transformers B and C
Two checkpoints of the same model, for a single re-identification task. Each checkpoint is in its own folder (transformer_b/, transformer_c/) with:
- config.json (the architecture)
- model.pt (that checkpoint's corrupted weights)
- train_.csv (columns target = author id, text = message)
- test_.csv (columns index, text; the same test set in both folders)
Task: author re-identification. For each message in test_.csv, return the 5 messages most likely written by the same author. The authors in train and test do not overlap. Scored by mAP@5.
Subproblem 1
Step 1: (R)ead the Problem
- The goal: given a set of text messages, classify their sentiment (positive, neutral, negative).
- The problem: many words in the model's vocabulary have had their embeddings zeroed out. The model is effectively blind to them. However, BERT uses a WordPiece tokenizer. This means missing words can often be built out of smaller subword pieces that aren't zeroed out.
- That is why val_dataset.csv is useful: since we cannot train, we need a way to measure whether our repair strategy (e.g. replacing zeroed embeddings with the average of their subwords) actually restores the model's accuracy before we submit on test.csv.
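To see why subword pieces can stand in for a missing word, here is a minimal sketch of WordPiece's greedy longest-match-first split for a single word. The toy vocabulary and the `wordpiece` helper are hypothetical; the real BERT tokenizer also lowercases, splits on whitespace and punctuation first, and caps word length, none of which is modeled here.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one word, in the style of BERT's WordPiece."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry the ## prefix
            if candidate in vocab:
                match = candidate  # longest piece that starts here
                break
            end -= 1
        if match is None:
            return [unk]  # no piece fits at this position: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"commit", "##ment", "un", "##happy", "happy"}
print(wordpiece("commitment", toy_vocab))  # ['commit', '##ment']
print(wordpiece("unhappy", toy_vocab))  # ['un', '##happy']
print(wordpiece("xyz", toy_vocab))  # ['[UNK]']
```

If commitment itself were in the vocabulary, the longest match would be the whole word and no pieces would come back, which is why repairing a broken word means splitting it by hand.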
What an embedding table looks like
An embedding table is just a big lookup matrix: one row per token in the vocabulary, and each row is that token's vector. Let's read it straight out of the model and look at it.
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Load the broken sentiment model and its tokenizer
model = AutoModelForSequenceClassification.from_pretrained("transformer_a")
tokenizer = AutoTokenizer.from_pretrained("transformer_a")
# The embedding table: one row per vocabulary token, each row a vector
embedding_table = model.bert.embeddings.word_embeddings.weight.detach()
print(
    "Embedding table shape:", tuple(embedding_table.shape), " (vocab_size, hidden_size)"
)
# Each row is one token's vector. Here are a few words (first 8 numbers of each):
print("\nA few words:")
for word in ["good", "bad", "movie", "love"]:
    token_id = tokenizer.convert_tokens_to_ids(word)
    preview = [round(number, 3) for number in embedding_table[token_id, :8].tolist()]
    print(f"\t{word:>6} (id {token_id:>5}): {preview} ...")
# Some rows were wiped to all zeros: that is the corruption
zero_rows = (embedding_table == 0).all(dim=1)
print(f"\nAll-zero (broken) rows: {int(zero_rows.sum())} of {embedding_table.shape[0]}")
broken_id = int(zero_rows.nonzero()[0])
preview = [round(number, 3) for number in embedding_table[broken_id, :8].tolist()]
print(f"Example broken row (id {broken_id}): {preview} ...")
Loading weights: 100%|██████████| 201/201 [00:00<00:00, 8234.01it/s]
Embedding table shape: (30522, 768)  (vocab_size, hidden_size)

A few words:
      good (id  2204): [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ...
       bad (id  2919): [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ...
     movie (id  3185): [-0.002, -0.008, -0.025, -0.022, -0.035, -0.042, -0.002, -0.04] ...
      love (id  2293): [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ...

All-zero (broken) rows: 12208 of 30522
Example broken row (id 1): [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ...
Repairing one word, by hand
Before writing any loop, let's do the whole repair on a single word so the idea is concrete. We'll take commitment, a word whose row was zeroed, and rebuild it from the pieces of itself that survived. One catch: a broken word is still an entry in the vocabulary, so the tokenizer happily returns it whole and gives us no pieces to work with. So we split it ourselves into the largest real sub-pieces (commit + ##ment), both of which are still intact, and average them. Four short steps: (1) confirm commitment is broken, (2) split it into surviving pieces, (3) average them, (4) write the average back into the table.
# Step 1: pick a broken word and show its row is all zeros
embedding_table = model.bert.embeddings.word_embeddings.weight.detach().clone()
vocab = tokenizer.get_vocab() # token string -> row id
word = "commitment"
word_id = vocab[word]
print(f'"{word}" lives in row {word_id} of the table')
print(
    "Its vector:",
    [round(number, 3) for number in embedding_table[word_id, :8].tolist()],
    "...",
)
print("Is the whole row zero (broken)?", bool((embedding_table[word_id] == 0).all()))
# Step 2: split the word into pieces ourselves.
# The tokenizer keeps a known word whole, even when its row is broken:
print("Tokenizer keeps it whole:", tokenizer.tokenize(word))
# So we greedily take the largest real sub-piece each time, skipping the broken word itself
pieces, start = [], 0
while start < len(word):
    end = len(word)
    while start < end:
        substring = word[start:end]
        form = (
            substring if start == 0 else "##" + substring
        )  # word-initial or continuation
        is_whole_word = start == 0 and end == len(word)
        if form in vocab and not is_whole_word:
            pieces.append(form)
            break
        end -= 1
    else:
        # no vocabulary piece starts here: skip one character so the loop cannot stall
        end = start + 1
    start = end
print("Pieces of", word, "->", pieces)
for piece in pieces:
    print(
        f"\t{piece:>8} (row {vocab[piece]:>5}): broken? {bool((embedding_table[vocab[piece]] == 0).all())}"
    )
Tokenizer keeps it whole: ['commitment']
Pieces of commitment -> ['commit', '##ment']
      commit (row 10797): broken? False
      ##ment (row  3672): broken? False
# Step 3: average the surviving pieces -> the repaired vector for the word
piece_ids = [vocab[piece] for piece in pieces]
repaired = embedding_table[piece_ids].mean(dim=0)
print("Average of the surviving pieces:")
print([round(number, 3) for number in repaired[:8].tolist()], "...")
Average of the surviving pieces: [-0.024, -0.026, -0.053, -0.048, -0.03, -0.023, -0.043, 0.023] ...
# Step 4: drop the repaired vector back into the word's row of the table
embedding_table[word_id] = repaired
print(
    f'"{word}" row after repair:',
    [round(number, 3) for number in embedding_table[word_id, :8].tolist()],
    "...",
)
print("Still all zero?", bool((embedding_table[word_id] == 0).all()))
"commitment" row after repair: [-0.024, -0.026, -0.053, -0.048, -0.03, -0.023, -0.043, 0.023] ...
Still all zero? False
A few questions that usually come up:
- Are all words broken into subwords? No. Common words have their own single entry in the vocabulary (good, movie, and love are each one token). Only rarer words get split into pieces. So a word is either one token, or a handful of subword pieces when the whole word is not in the vocabulary.
- How are token ids decided? A token id is just its row number in the vocabulary (its line in vocab.txt, counting from 0). The order is fixed when the tokenizer is built, so the id is stable, not random. For BERT the special tokens sit at fixed low ids ([PAD]=0, [UNK]=100, [CLS]=101, [SEP]=102, [MASK]=103), with the wordpieces after them. That same id is the row used to look the token up in the embedding table.
- What if a subword is not in the embedding table? Every vocabulary token has a row, so a subword always has one. The case that actually matters is when that row is also zeroed. Then you skip it and average only the surviving pieces. If none of a token's pieces survived, fall back to a small random vector.
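The id-to-row rule and the skip-then-fall-back rule fit in a few lines of NumPy. The five-token vocabulary, its 3-number rows, and the `rebuild` helper are made up for illustration; the real vocabulary here has 30522 entries, one per embedding row.

```python
import numpy as np

# Hypothetical vocab.txt contents: one token per line, so the line number is the id
vocab_lines = ["[PAD]", "commit", "##ment", "commitment", "##s"]
token_to_id = {token: line_number for line_number, token in enumerate(vocab_lines)}

# Toy embedding table: row i belongs to the token on line i (hidden size 3 here)
table = np.array([
    [0.1, 0.1, 0.1],  # [PAD]
    [0.2, 0.4, 0.6],  # commit
    [0.4, 0.0, 0.2],  # ##ment
    [0.0, 0.0, 0.0],  # commitment (zeroed by the corruption)
    [0.0, 0.0, 0.0],  # ##s (zeroed too)
])


def rebuild(pieces, rng=None, scale=0.02):
    """Average the pieces whose rows survived; fall back to a small random vector if none did."""
    rng = np.random.default_rng(0) if rng is None else rng
    alive = [token_to_id[p] for p in pieces if np.any(table[token_to_id[p]] != 0)]
    if alive:
        return table[alive].mean(axis=0)
    return rng.standard_normal(table.shape[1]) * scale


print(token_to_id["commitment"])  # 3: the row that needs repair
print(rebuild(["commit", "##ment"]))  # mean of rows 1 and 2
print(rebuild(["commit", "##s"]))  # ##s is zeroed, so only commit counts
```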
Step 2: (I)mplement the Baseline
We need a yardstick before changing anything. The evaluate function takes the labeled validation set (a dataframe loaded from val_dataset.csv), runs the model as a sentiment classifier, compares each prediction to the true label, and returns the macro F1 score. Then we score the corrupted model as-is to get the baseline number to beat.
import pandas as pd
import torch
from sklearn.metrics import f1_score
from tqdm import tqdm
# The model outputs LABEL_0/1/2, which map to these sentiments
LABEL2ID = {"neutral": 0, "positive": 1, "negative": 2}
def evaluate(model, tokenizer, dataframe):
    """Run the classifier on a labeled dataframe (columns `text`, `labels`) and return macro F1."""
    model.eval()
    device = next(model.parameters()).device
    texts = dataframe["text"].astype(str).tolist()
    gold = dataframe["labels"].map(LABEL2ID).tolist()
    preds = []
    with torch.no_grad():
        for start in tqdm(range(0, len(texts), 32), desc="Evaluating"):
            batch = tokenizer(
                texts[start : start + 32],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=128,
            ).to(device)
            logits = model(**batch).logits
            preds.extend(logits.argmax(dim=1).cpu().tolist())
    return f1_score(gold, preds, average="macro")
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Baseline: score the corrupted model before any repair
model = AutoModelForSequenceClassification.from_pretrained("transformer_a")
tokenizer = AutoTokenizer.from_pretrained("transformer_a")
val = pd.read_csv("transformer_a/val_dataset.csv")
print("Corrupted model F1:", evaluate(model, tokenizer, val))
Loading weights: 100%|██████████| 201/201 [00:00<00:00, 17603.62it/s]
Evaluating: 100%|██████████| 79/79 [00:17<00:00, 4.58it/s]
Corrupted model F1: 0.2861937747500025
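The evaluate function uses f1_score(..., average="macro"), which averages the per-class F1 scores with equal weight, so a model that collapses onto one class scores poorly even if that class is common. Here is a plain-Python sketch of that computation; macro_f1 is a hypothetical helper, and it treats undefined precision or recall as 0, which matches scikit-learn's default behavior after its warning.

```python
def macro_f1(gold, pred, labels):
    """Unweighted mean of per-class F1; undefined precision or recall counts as 0."""
    per_class = []
    for label in labels:
        tp = sum(1 for g, p in zip(gold, pred) if g == label and p == label)
        fp = sum(1 for g, p in zip(gold, pred) if g != label and p == label)
        fn = sum(1 for g, p in zip(gold, pred) if g == label and p != label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        per_class.append(f1)
    return sum(per_class) / len(per_class)


# A collapsed model that predicts "neutral" (0) for every message
gold = [0, 0, 1, 2]
pred = [0, 0, 0, 0]
print(macro_f1(gold, pred, labels=[0, 1, 2]))  # (2/3 + 0 + 0) / 3, about 0.222
```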
The baseline fix: rebuild zeroed rows from surviving pieces
For each token whose embedding row is all zeros, we rebuild a vector from the pieces of the word that still have one:
- Find the broken rows. The rows of the embedding table that are entirely zero.
- For each broken token, collect its surviving pieces. Look at the subword pieces inside the token (its substrings that are themselves vocabulary tokens) and keep only the ones whose embedding was not zeroed.
- Average them. The mean of those surviving pieces becomes the token's repaired vector.
- Fallback. If a token has no surviving pieces at all, drop in a small random vector so it is not stuck at zero.
- Write the rows back into the model, then re-score on the validation set.
The bet: a word's meaning lives in its pieces, and most pieces survived, so their average lands close to where the original vector was.
# Start from a fresh corrupted model
model = AutoModelForSequenceClassification.from_pretrained("transformer_a")
tokenizer = AutoTokenizer.from_pretrained("transformer_a")
embedding_table = (
    model.bert.embeddings.word_embeddings.weight.data
)  # editable (vocab, hidden)
vocab = tokenizer.get_vocab() # token string -> id
id_to_token = {index: token for token, index in vocab.items()}
Loading weights: 100%|██████████| 201/201 [00:00<00:00, 16589.69it/s]
embedding_table
tensor([[-0.0102, -0.0615, -0.0265, ..., -0.0198, -0.0372, -0.0097],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[-0.0197, -0.0627, -0.0326, ..., -0.0165, -0.0420, -0.0032],
...,
[-0.0218, -0.0556, -0.0135, ..., -0.0043, -0.0151, -0.0249],
[-0.0462, -0.0565, -0.0019, ..., 0.0157, -0.0139, -0.0094],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]])
# Peek at the first 20 entries of the vocabulary (token -> row id)
dict(list(vocab.items())[:20])
{'longitude': 20413,
'cancers': 25409,
'bonfire': 28698,
'zimbabwe': 11399,
'ri': 15544,
'##lmer': 23398,
'##lidae': 21595,
'grandchildren': 13628,
'অ': 1347,
'psychiatry': 18420,
'tribute': 7050,
'notoriety': 23215,
'##mel': 10199,
'##和': 30322,
'relied': 13538,
'blah': 27984,
'##tan': 5794,
'mustang': 18851,
'rbi': 16929,
'crossroads': 16760}
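The peek above is not in id order: get_vocab() returns a plain dictionary, and its iteration order does not follow the ids. Sorting by id lists tokens in the same order as the rows of the embedding table. A small sketch on a few entries copied from the output above:

```python
# A few (token, id) pairs copied from the printed dictionary, in the order get_vocab() gave them
sample = {"longitude": 20413, "cancers": 25409, "ri": 15544, "##tan": 5794, "অ": 1347}

# Sorting by id lists tokens in embedding-table row order
in_row_order = sorted(sample.items(), key=lambda item: item[1])
print(in_row_order)
# [('অ', 1347), ('##tan', 5794), ('ri', 15544), ('longitude', 20413), ('cancers', 25409)]
```

On the full vocabulary, sorting vocab.items() the same way gives a list whose position i is the token in row i, assuming the ids run from 0 to vocab_size - 1 without gaps.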
zero_rows = (embedding_table == 0).all(dim=1)
broken_ids = zero_rows.nonzero().flatten().tolist()
alive = (~zero_rows).tolist() # fast python lookup: did row `piece_id` survive?
def surviving_pieces(broken_id):
    """Ids of subword pieces inside this token whose embedding still exists."""
    word = id_to_token[broken_id]
    word = word[2:] if word.startswith("##") else word  # surface form
    found = []
    for start in range(len(word)):
        for end in range(start + 1, len(word) + 1):
            for form in (
                word[start:end],
                "##" + word[start:end],
            ):  # word-initial or continuation
                piece_id = vocab.get(form)
                if piece_id is not None and piece_id != broken_id and alive[piece_id]:
                    found.append(piece_id)
    return found
# Rebuild each broken row as the mean of its surviving pieces (random fallback)
for broken_id in tqdm(broken_ids, desc="Repairing"):
    pieces = surviving_pieces(broken_id)
    if pieces:
        embedding_table[broken_id] = embedding_table[pieces].mean(dim=0)
    else:
        embedding_table[broken_id] = torch.randn(embedding_table.shape[1]) * 0.02
print("Repaired model F1:", evaluate(model, tokenizer, val))
Repairing: 100%|██████████| 12208/12208 [00:01<00:00, 11549.31it/s]
Evaluating: 100%|██████████| 79/79 [00:17<00:00, 4.51it/s]
Repaired model F1: 0.45903143693359755
Step 3: (C)heck for Errors
A few things to check for this repair:
- Did the zeros actually go away? After repairing, re-count the all-zero rows. Before the fallback runs, only tokens with no surviving pieces should remain; after it, the count should be zero. If it did not drop, the write-back never happened.
- How many tokens fell back to random? Those rows are pure noise. If the count is high, your piece-finding is too strict; loosen it before trusting the score.
- Did F1 move on val? Every change should beat the 0.286 baseline. If a change does not move the number, it did not help.
- Eyeball a repaired word. Take a sentiment word that was broken (like good) and check that its new vector looks like a real embedding: similar in scale to the surviving rows, not tiny or huge.
- Mind the scale. A mean of pieces is usually fine, but a random fallback can be off-scale. Make sure it roughly matches the magnitude of healthy rows.
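The zero-count and scale checks can be scripted. A minimal NumPy sketch; repair_report and scaled_fallback are hypothetical helpers, and in the notebook you would pass embedding_table.numpy() plus the zero_rows mask captured before repairing (as .numpy()).

```python
import numpy as np


def repair_report(table, was_broken):
    """Summarize a repaired table. table: (vocab, hidden); was_broken: bool mask of rows that were zero."""
    still_zero = np.all(table == 0, axis=1)
    norms = np.linalg.norm(table, axis=1)
    repaired = was_broken & ~still_zero
    return {
        "still_zero": int(still_zero.sum()),
        "healthy_norm": float(norms[~was_broken].mean()),
        "repaired_norm": float(norms[repaired].mean()) if repaired.any() else float("nan"),
    }


def scaled_fallback(rng, hidden, target_norm):
    """A random direction rescaled to a chosen length, e.g. the typical healthy-row norm."""
    v = rng.standard_normal(hidden)
    return v / np.linalg.norm(v) * target_norm


# Toy table: one healthy row, one row still zero, one repaired row
table = np.array([[3.0, 4.0], [0.0, 0.0], [0.6, 0.8]])
was_broken = np.array([False, True, True])
print(repair_report(table, was_broken))
```

A repaired_norm far from healthy_norm is the "tiny or huge" warning sign; scaled_fallback is one way to keep a fallback vector on the same scale.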
Step 4: (E)nhance the Solution
A stronger repair
The baseline averages every surviving piece equally and uses random noise when nothing survives. Three cheap upgrades, still no training:
- Weight by piece length. A long piece like happy says more about a word than a single character, so weight each piece by its length in the average.
- Repair in passes. Some tokens' pieces were themselves zeroed. Repairing a few times lets rows filled in one pass act as pieces for the next.
- Smarter fallback. When a token has no surviving pieces at all, use the average of all surviving rows (the center of the table) instead of random noise.
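The length-weighted average can be written out. With $P(w)$ the surviving pieces found for a broken token $w$ (the surviving_pieces list, which can contain the same piece more than once), $\mathbf{e}_p$ a piece's row, and $\ell_p$ its length without the ## prefix:

$$\hat{\mathbf{e}}_w = \frac{\sum_{p \in P(w)} \ell_p \, \mathbf{e}_p}{\sum_{p \in P(w)} \ell_p}$$

For two pieces of lengths 6 and 4 (like commit and ##ment), the weights are 6/10 = 0.6 and 4/10 = 0.4, where the plain mean would give 0.5 each.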
# Stronger repair: weight pieces by length, repair in passes, center fallback
model = AutoModelForSequenceClassification.from_pretrained("transformer_a")
tokenizer = AutoTokenizer.from_pretrained("transformer_a")
embedding_table = model.bert.embeddings.word_embeddings.weight.data
vocab = tokenizer.get_vocab()
id_to_token = {index: token for token, index in vocab.items()}
# fallback = the average of all rows that survived (the center of the table)
center = embedding_table[(embedding_table != 0).any(dim=1)].mean(dim=0)
for _ in range(3):  # repair in passes: rows filled now become usable pieces next pass
    zero_rows = (embedding_table == 0).all(dim=1)
    alive = (~zero_rows).tolist()  # surviving_pieces() reads this global, so refresh it each pass
    broken_ids = zero_rows.nonzero().flatten().tolist()
    if not broken_ids:
        break
    for broken_id in tqdm(broken_ids, desc="Repairing"):
        pieces = surviving_pieces(broken_id)  # reuses the helper from the baseline cell
        if pieces:
            piece_lengths = torch.tensor(
                [len(id_to_token[piece_id].lstrip("#")) for piece_id in pieces],
                dtype=torch.float,
            )  # longer pieces count more
            embedding_table[broken_id] = (
                embedding_table[pieces] * piece_lengths[:, None]
            ).sum(0) / piece_lengths.sum()
# any token that never found a piece gets the center vector instead of random noise
for broken_id in (embedding_table == 0).all(dim=1).nonzero().flatten().tolist():
    embedding_table[broken_id] = center
print("Enhanced repair F1:", evaluate(model, tokenizer, val))
Loading weights: 100%|██████████| 201/201 [00:00<00:00, 7896.88it/s]
Repairing: 100%|██████████| 12208/12208 [00:01<00:00, 8023.47it/s]
Repairing: 100%|██████████| 370/370 [00:00<00:00, 989727.35it/s]
Repairing: 100%|██████████| 370/370 [00:00<00:00, 1218125.97it/s]
Evaluating: 100%|██████████| 79/79 [00:19<00:00, 4.11it/s]
Enhanced repair F1: 0.47578691807159595
What if training is allowed?
This subproblem bans training, which is why we repair the table by hand. For contrast, here is the training route: freeze the rest of the model, make only the embedding table trainable, and learn the wiped rows back with a normal classification loop on the labeled data.
# not allowed in this subproblem, shown for contrast
for param in model.parameters():
    param.requires_grad = False
model.bert.embeddings.word_embeddings.weight.requires_grad = True
# then a standard training loop on val_dataset: tokenize -> logits -> CrossEntropy -> step
Gradient descent simply fills the zeroed rows with whatever values make the predictions correct. That is the same "just fine-tune it" idea you can use in Subproblem 2, and it is the thing this subproblem forbids, so the lesson stays on reconstructing the table from its own surviving pieces.
Subproblem 2
Step 1: (R)ead the Problem
- The goal: given the test set of messages, group them by author. For each test message, return the other messages most likely written by the same person. You never see who the authors are; you match them by writing style.
- The model was not trained on train_.csv, but on a separate corpus of similar text (similar distribution). So it already carries a general "same author?" signal that transfers to people it has never seen.
- That is why train_.csv is useful: even though the model never learned from it, its authors are labeled, so we can measure mAP@5 there to confirm our repair restored the signal before submitting on test_.csv.
What one block looks like
Each block is self-attention followed by an FFN, each wrapped in Add & Norm. Here is what that looks like in this model: listing the weights of a single block (model.layers.0) gives:
model.layers.0.attn.Wqkv.weight (2304, 768) # self-attention: build Q, K, V
model.layers.0.attn.Wo.weight (768, 768) # self-attention: combine the result
model.layers.0.mlp_norm.weight (768,) # Add & Norm
model.layers.0.mlp.Wi.weight (2304, 768) # FFN: in
model.layers.0.mlp.Wo.weight (768, 1152) # FFN: out
- Self-attention = attn.Wqkv (the tokens build query, key, and value so they can look at each other) and attn.Wo (combine what they gathered).
- FFN = mlp.Wi then mlp.Wo (process each token on its own).
- Add & Norm = the *_norm weights.
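To connect the weight names to the computation, here is a minimal NumPy sketch of one block's forward pass with the same weight shapes at toy sizes. It is a simplification, not ModernBERT's implementation: there are no rotary position embeddings, attention mask, local-window attention, or dropout, and a tanh-approximated GELU stands in for the model's activation. Two details are assumptions consistent with the shapes printed above: mlp.Wi outputs 2 × 1152 = 2304 features that are split into an input half and a gate half (a gated linear unit), which is why mlp.Wo takes 1152 inputs; and the first block skips the attention norm because its input is already normalized, which would explain why model.layers.0 lists only mlp_norm.

```python
import numpy as np


def layer_norm(x, weight, eps=1e-5):
    # LayerNorm with a learned scale only, like a *_norm.weight of shape (hidden,)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def block(x, w, n_heads, first_layer=False):
    """One simplified pre-norm block. x: (tokens, hidden); w: weights shaped like the checkpoint's."""
    tokens, hidden = x.shape
    head_dim = hidden // n_heads
    # Self-attention (assumption: the first block has no attn_norm)
    h = x if first_layer else layer_norm(x, w["attn_norm"])
    q, k, v = np.split(h @ w["attn.Wqkv"].T, 3, axis=-1)  # Wqkv: (3*hidden, hidden)
    q, k, v = [m.reshape(tokens, n_heads, head_dim).transpose(1, 0, 2) for m in (q, k, v)]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)  # (heads, tokens, tokens)
    gathered = (softmax(scores) @ v).transpose(1, 0, 2).reshape(tokens, hidden)
    x = x + gathered @ w["attn.Wo"].T  # Add: residual around attention
    # FFN (assumption: Wi is a GLU projection split into input and gate halves)
    h = layer_norm(x, w["mlp_norm"])
    inp, gate = np.split(h @ w["mlp.Wi"].T, 2, axis=-1)  # Wi: (2*intermediate, hidden)
    x = x + (gelu(inp) * gate) @ w["mlp.Wo"].T  # Add: residual; Wo: (hidden, intermediate)
    return x


# Toy sizes; the real model uses hidden=768 and intermediate=1152
hidden, intermediate, n_heads, tokens = 8, 12, 2, 5
rng = np.random.default_rng(0)
w = {
    "attn_norm": np.ones(hidden),
    "attn.Wqkv": rng.normal(size=(3 * hidden, hidden)),
    "attn.Wo": rng.normal(size=(hidden, hidden)),
    "mlp_norm": np.ones(hidden),
    "mlp.Wi": rng.normal(size=(2 * intermediate, hidden)),
    "mlp.Wo": rng.normal(size=(hidden, intermediate)),
}
x = rng.normal(size=(tokens, hidden))
out = block(x, w, n_heads, first_layer=True)
print(out.shape)  # (5, 8): the block keeps the (tokens, hidden) shape, so blocks can be stacked
```

In this sketch, zeroing both output matrices (attn.Wo and mlp.Wo) turns the block into the identity, because the residual connections pass the input straight through.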
Stack 22 of these blocks on top of the embedding table, and that is the whole encoder. Let's look inside the checkpoint to confirm.
import torch
import re
# A .pt file is just a dictionary: weight name -> tensor. So we can look inside.
state_dict = torch.load("transformer_b/model.pt", map_location="cpu", weights_only=True)
# The embedding table: one row per vocab token (token id -> starting vector)
print(
    "Embedding table:",
    tuple(state_dict["model.embeddings.tok_embeddings.weight"].shape),
)
# Everything inside a single block
print("\nOne block (model.layers.0):")
for name, tensor in state_dict.items():
    if name.startswith("model.layers.0."):
        print(f"\t{name:35s} {tuple(tensor.shape)}")
# How many blocks are stacked?
blocks = {
    int(re.search(r"layers\.(\d+)\.", name).group(1))
    for name in state_dict
    if ".layers." in name
}
print("\nNumber of blocks:", len(blocks))
Embedding table: (50368, 768)

One block (model.layers.0):
    model.layers.0.attn.Wqkv.weight     (2304, 768)
    model.layers.0.attn.Wo.weight       (768, 768)
    model.layers.0.mlp_norm.weight      (768,)
    model.layers.0.mlp.Wi.weight        (2304, 768)
    model.layers.0.mlp.Wo.weight        (768, 1152)

Number of blocks: 22
Building a model from the weights
Now that we understand what the weights are, let's turn them into a model we can actually run. Two steps: first read the config (the architecture blueprint) and build an empty model from it, then pour the saved weights into that model.
Let's read the config and see which weights the model expects.
from transformers import AutoConfig, AutoModel
# Read the architecture blueprint from config.json (no weights yet)
# Then build an empty model with that architecture (random weights for now)
config = AutoConfig.from_pretrained("transformer_b")
model = AutoModel.from_config(config)
# Ask the empty model which weight names it expects (its "slots")
print("The encoder expects (first 5):")
for name in list(model.state_dict().keys())[:5]:
    print(f"\t{name}")
The encoder expects (first 5):
    embeddings.tok_embeddings.weight
    embeddings.norm.weight
    layers.0.attn.Wqkv.weight
    layers.0.attn.Wo.weight
    layers.0.mlp_norm.weight
Notice the names the empty encoder expects: embeddings.tok_embeddings.weight, layers.0.attn.Wqkv.weight, and so on.
But the saved checkpoint names those same weights with model. in front (model.embeddings.tok_embeddings.weight), as we saw when we listed its keys above. load_state_dict fills the model by matching names exactly, so as-is the names would not line up. The weights were saved while the encoder was nested one level inside a larger object, which put model. on the front of every name.
So we peel off the model. prefix to make the names match, and skip the one extra weight (centers) that is not part of the encoder.
# The saved weights (the same .pt dictionary we inspected above)
state_dict = torch.load("transformer_b/model.pt", map_location="cpu", weights_only=True)
# Rename: drop the "model." prefix so the names match the encoder's slots,
# and keep only the encoder weights (skip the extra "centers" weight)
encoder_state_dict = {
    name.removeprefix("model."): tensor
    for name, tensor in state_dict.items()
    if name.startswith("model.")
}
# Now the names line up, so the weights load cleanly into our model
model.load_state_dict(encoder_state_dict)
print("Loaded weights into", type(model).__name__)
Loaded weights into ModernBertModel
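Before loading, it can help to see exactly which names disagree. load_state_dict is strict by default and raises on any mismatch; with strict=False it instead returns the missing_keys and unexpected_keys. Below is the same comparison in plain Python, with hypothetical helpers strip_prefix and diff_keys and a few key names copied from the outputs above.

```python
def strip_prefix(names, prefix="model."):
    """Mirror the notebook's rename: keep only names with the prefix, then drop it."""
    return [name.removeprefix(prefix) for name in names if name.startswith(prefix)]


def diff_keys(expected, provided):
    """(missing, unexpected): names the model wants but did not get, and names it has no slot for."""
    expected, provided = set(expected), set(provided)
    return sorted(expected - provided), sorted(provided - expected)


expected = ["embeddings.tok_embeddings.weight", "layers.0.attn.Wqkv.weight"]
saved = ["model.embeddings.tok_embeddings.weight", "model.layers.0.attn.Wqkv.weight", "centers"]
print(diff_keys(expected, saved))  # everything missing, everything unexpected
print(diff_keys(expected, strip_prefix(saved)))  # ([], [])
```

On the real objects, diff_keys(model.state_dict().keys(), encoder_state_dict.keys()) should return two empty lists, since the strict load above succeeded.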
Let's package this setup as a reusable function, build_model():
import torch
from transformers import AutoConfig, AutoModel
def build_model(config_path, weights_path):
    """Build an encoder from a config file and a weights file.
    Loads the config and weights, builds the empty model, then loads the weights,
    dropping the "model." prefix so names match the encoder's slots and skipping
    non-encoder weights (like "centers").
    """
    config = AutoConfig.from_pretrained(config_path)
    weights = torch.load(weights_path, map_location="cpu", weights_only=True)
    model = AutoModel.from_config(config)
    encoder_state_dict = {
        name.removeprefix("model."): tensor
        for name, tensor in weights.items()
        if name.startswith("model.")
    }
    model.load_state_dict(encoder_state_dict)
    return model
# Usage:
model = build_model("transformer_b/config.json", "transformer_b/model.pt")
print("Built", type(model).__name__)
Built ModernBertModel
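Step 2: (I)mplement the Baseline
We first need a development loop to iterate with. We will assume the evaluate function is given, so we only have to pass it a model, its tokenizer, and a list of labeled (text, author) examples. It embeds each message (mean-pooled, then L2-normalized) and scores mAP@5 by author.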
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
def evaluate(model, tokenizer, examples):
    """Score a model on labeled examples, a list of (text, author) tuples: embed each, then mAP@5 by author."""
    texts = [text for text, _ in examples]
    labels = [author for _, author in examples]
    device = next(model.parameters()).device
    model.eval()
    def embed(texts):
        chunks = []
        for start in tqdm(range(0, len(texts), 32), desc="Embedding"):
            batch = tokenizer(
                texts[start : start + 32],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(device)
            with torch.no_grad():
                hidden = model(**batch).last_hidden_state  # (batch, tokens, hidden)
            mask = batch["attention_mask"].unsqueeze(-1).float()
            pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)  # mean-pool
            chunks.append(F.normalize(pooled, dim=1).cpu())
        return torch.cat(chunks)
    def map_at_5(embeddings, labels):
        labels = np.asarray(labels)
        sims = (embeddings @ embeddings.T).numpy()
        np.fill_diagonal(sims, -np.inf)  # never retrieve the message itself
        average_precisions = []
        for query in range(len(labels)):
            top5 = np.argpartition(-sims[query], 5)[:5]
            top5 = top5[np.argsort(-sims[query][top5])]
            relevant = labels[top5] == labels[query]
            num_relevant = int((labels == labels[query]).sum() - 1)
            hits, score = 0, 0.0
            for rank, is_relevant in enumerate(relevant, start=1):
                if is_relevant:
                    hits += 1
                    score += hits / rank
            average_precisions.append(
                score / min(num_relevant, 5) if num_relevant else 0.0
            )
        return float(np.mean(average_precisions))
    embeddings = embed(texts)
    return map_at_5(embeddings, labels)
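To make the scoring concrete, here is the per-query computation from map_at_5 pulled out as a standalone function (the name average_precision_at_k is ours), with hand-checkable examples. mAP@5 is the mean of this value over all queries.

```python
def average_precision_at_k(relevant, num_relevant, k=5):
    """AP@k for one query, normalized by min(num_relevant, k) as in map_at_5."""
    hits, score = 0, 0.0
    for rank, is_relevant in enumerate(relevant[:k], start=1):
        if is_relevant:
            hits += 1
            score += hits / rank  # precision at this rank, counted only at hits
    return score / min(num_relevant, k) if num_relevant else 0.0


# Same-author hits at ranks 1 and 3; the query's author has 2 other messages in total
print(average_precision_at_k([True, False, True, False, False], num_relevant=2))  # (1/1 + 2/3) / 2
# One hit at rank 2, but the author has 3 other messages
print(average_precision_at_k([False, True, False, False, False], num_relevant=3))  # (1/2) / 3
```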
Let's see what the corrupted models score on a random sample of the labeled train set.
import pandas as pd
from transformers import AutoTokenizer
num_samples_to_evaluate = 1000
device = (
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
print(f"Using device: {device}")
# --- for transformer B
# Build the corrupted model with our helper
model = build_model(
    config_path="transformer_b/config.json",
    weights_path="transformer_b/model.pt",
).to(device)
tokenizer = AutoTokenizer.from_pretrained("transformer_b/tokenizer")
# Score it on a random subsample of the labeled train set (faster while developing)
train = (
    pd.read_csv("transformer_b/train_.csv")
    .dropna(subset=["text", "target"])
    .sample(num_samples_to_evaluate, random_state=42)
)
examples = list(zip(train["text"].astype(str), train["target"]))
print(
    f"[Transformer B] Corrupted model mAP@5 ({num_samples_to_evaluate} samples):",
    evaluate(model, tokenizer, examples),
)
# --- for transformer C
# Same model architecture, so C reuses B's config and tokenizer; only the weights differ
model = build_model(
    config_path="transformer_b/config.json",
    weights_path="transformer_c/model.pt",
).to(device)
tokenizer = AutoTokenizer.from_pretrained("transformer_b/tokenizer")
# Score it on a random subsample of the labeled train set (faster while developing)
train = (
    pd.read_csv("transformer_b/train_.csv")
    .dropna(subset=["text", "target"])
    .sample(num_samples_to_evaluate, random_state=42)
)
examples = list(zip(train["text"].astype(str), train["target"]))
print(
    f"[Transformer C] Corrupted model mAP@5 ({num_samples_to_evaluate} samples):",
    evaluate(model, tokenizer, examples),
)
Using device: mps
Embedding: 100%|██████████| 32/32 [00:28<00:00, 1.14it/s]
[Transformer B] Corrupted model mAP@5 (1000 samples): 0.03745333333333333
Embedding: 100%|██████████| 32/32 [00:16<00:00, 1.90it/s]
[Transformer C] Corrupted model mAP@5 (1000 samples): 0.04019
Both corrupted checkpoints score about 0.04 mAP@5, barely above random. The corruption really did break the embeddings. That ~0.04 is our baseline: the number any fix has to beat.
A simple approach: repair the weights¶
We have two broken copies of the same model. B has its layers in the right order but many weights zeroed; C has fewer zeros but its layers are shuffled. Each can fix the other's weakness:
- Recover C's order. Each layer in B has a twin in C (same model, nearly the same weights). For each B layer, find the C layer whose weights look most like it. That tells us where each of C's layers belongs.
- Reorder C so its layers sit in the right positions.
- Merge. Take C's (now-ordered) cleaner weights, and wherever C still has a zero, fill it in from B.
- Build a model from the repaired weights and score it.
import torch.nn.functional as F
# Both checkpoints are {name: tensor} dicts, like we inspected earlier
weights_b = torch.load("transformer_b/model.pt", map_location="cpu", weights_only=True)
weights_c = torch.load("transformer_c/model.pt", map_location="cpu", weights_only=True)
n_blocks = 22
# Fingerprint a block by flattening one big weight (its Wqkv) and normalizing it
def fingerprint(weights, layer):
    flat = weights[f"model.layers.{layer}.attn.Wqkv.weight"].flatten().float()
    return F.normalize(flat, dim=0)
# 1 + 2. For each B layer, find the C layer that looks most like it (cosine similarity)
fingerprints_b = torch.stack(
    [fingerprint(weights_b, layer) for layer in range(n_blocks)]
)
fingerprints_c = torch.stack(
    [fingerprint(weights_c, layer) for layer in range(n_blocks)]
)
match = (fingerprints_b @ fingerprints_c.T).argmax(
    dim=1
)  # match[layer] = the C layer that belongs at that position
# 3. Build repaired weights: pull C's matched layer into each slot, fill its zeros from B
repaired = {}
for name, weight_b in weights_b.items():
    if name.startswith("model.layers."):
        layer = int(name.split(".")[2])
        matched_layer = match[layer].item()
        weight_c = weights_c[
            name.replace(f"layers.{layer}.", f"layers.{matched_layer}.")
        ]
    else:
        weight_c = weights_c[name]  # non-layer weights are not shuffled
    repaired[name] = torch.where(
        weight_c != 0, weight_c, weight_b
    )  # prefer C, fall back to B
# 4. Build a model from the repaired weights and score it
torch.save(repaired, "repaired.pt")
model = build_model("transformer_b/config.json", "repaired.pt").to(device)
print("Repaired mAP@5:", evaluate(model, tokenizer, examples))
Embedding: 100%|██████████| 32/32 [00:17<00:00, 1.83it/s]
Repaired mAP@5: 0.11806333333333335
A training-based solution (Subproblem 2 allows training)
Instead of reconstructing the weights, we can fine-tune the corrupted model on the train authors. Training updates every weight (including the zeroed ones), so the model re-learns useful values, and it pulls each author's messages into a tight cluster.
The setup: put a small classifier head on top of the embedding (one output per train author) and train so each message is classified as its own author. We never use the head afterward. Training it is just the way to shape the encoder. Then we drop the head and score the encoder's embeddings as usual.
import torch.nn as nn
from torch.optim import AdamW
# Start from the corrupted encoder (a fresh copy)
encoder = build_model("transformer_b/config.json", "transformer_b/model.pt").to(device)
# Map each train author to a class id, and add a classifier head on top of the embedding
train_full = pd.read_csv("transformer_b/train_.csv").dropna(subset=["text", "target"])
authors = sorted(train_full["target"].unique())
author_to_id = {author: index for index, author in enumerate(authors)}
head = nn.Linear(encoder.config.hidden_size, len(authors)).to(device)
texts = train_full["text"].astype(str).tolist()
labels = [author_to_id[author] for author in train_full["target"]]
optimizer = AdamW(list(encoder.parameters()) + list(head.parameters()), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()
# One pass over the train set: classify each message as its own author
encoder.train()
for start in tqdm(range(0, len(texts), 16), desc="Training"):
    batch = tokenizer(
        texts[start : start + 16],
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    ).to(device)
    targets = torch.tensor(labels[start : start + 16]).to(device)
    hidden = encoder(**batch).last_hidden_state
    mask = batch["attention_mask"].unsqueeze(-1).float()
    pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)  # mean-pool
    loss = loss_fn(head(pooled), targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
# Drop the head, score just the encoder's embeddings
print("Fine-tuned mAP@5:", evaluate(encoder, tokenizer, examples))
Training: 100%|██████████| 720/720 [04:05<00:00, 2.94it/s]
Embedding: 100%|██████████| 32/32 [00:17<00:00, 1.82it/s]
Fine-tuned mAP@5: 0.07982333333333334
Solution 3: repair, then fine-tune
The first two approaches are not mutually exclusive. Repair fixes the structure (correct order, most weights restored), and fine-tuning closes the remaining gap. So the strongest move is to start from the repaired weights and fine-tune on top of them: the model begins close to working, so training has less left to fix, and it should beat fine-tuning from the corrupted checkpoint.
(Needs the repair cell above to have run, since it saved repaired.pt.)
import torch.nn as nn
from torch.optim import AdamW
# Start from the REPAIRED weights (saved by the repair cell), then fine-tune on top
encoder = build_model("transformer_b/config.json", "repaired.pt").to(device)
train_full = pd.read_csv("transformer_b/train_.csv").dropna(subset=["text", "target"])
authors = sorted(train_full["target"].unique())
author_to_id = {author: index for index, author in enumerate(authors)}
head = nn.Linear(encoder.config.hidden_size, len(authors)).to(device)
texts = train_full["text"].astype(str).tolist()
labels = [author_to_id[author] for author in train_full["target"]]
optimizer = AdamW(list(encoder.parameters()) + list(head.parameters()), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()
# Same training loop as Solution 2, but the encoder starts from the repaired weights
encoder.train()
for start in tqdm(range(0, len(texts), 16), desc="Training"):
    batch = tokenizer(
        texts[start : start + 16],
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    ).to(device)
    targets = torch.tensor(labels[start : start + 16]).to(device)
    hidden = encoder(**batch).last_hidden_state
    mask = batch["attention_mask"].unsqueeze(-1).float()
    pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)  # mean-pool
    loss = loss_fn(head(pooled), targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
# Drop the head, score just the encoder's embeddings
print("Repair + fine-tune mAP@5:", evaluate(encoder, tokenizer, examples))
Training: 100%|██████████| 720/720 [03:56<00:00, 3.04it/s]
Embedding: 100%|██████████| 32/32 [00:17<00:00, 1.86it/s]
Repair + fine-tune mAP@5: 0.175235
Step 3: (C)heck for Errors
A few quick checks for this problem:
- Is the layer matching one-to-one? If two B layers map to the same C layer, the order recovery is wrong. Switch argmax to one-to-one matching.
- Did the zeros drop? After merging, the fraction of zero weights should fall sharply. If not, the merge did nothing.
- Did the number move? Every fix should raise mAP@5 over the ~0.04 baseline. If a change does not move it, it did not help.
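- Mind train/eval overlap: score on messages you did not train on, or the number is optimistic. Both fine-tuning cells above train on all of train_.csv and then score a sample drawn from it, so their mAP@5 values are optimistic.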
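The first two checks can be scripted. A minimal NumPy sketch; matching_report and zero_fraction are hypothetical helpers, and in the notebook you would pass match.numpy() and weight tensors converted with .float().numpy().

```python
import numpy as np


def matching_report(match, n_blocks):
    """match[i] = C layer assigned to B slot i. One-to-one means every C layer is used exactly once."""
    counts = np.bincount(np.asarray(match), minlength=n_blocks)
    return {
        "one_to_one": bool((counts == 1).all()),
        "claimed_twice": np.flatnonzero(counts > 1).tolist(),
        "unclaimed": np.flatnonzero(counts == 0).tolist(),
    }


def zero_fraction(arrays):
    """Fraction of exactly-zero entries across a collection of weight arrays."""
    total = sum(a.size for a in arrays)
    zeros = sum(int((a == 0).sum()) for a in arrays)
    return zeros / total


print(matching_report([0, 2, 2], n_blocks=3))  # C layer 2 claimed twice, C layer 1 never used
print(zero_fraction([np.array([0.0, 1.0, 0.0, 2.0]), np.array([[0.0, 3.0]])]))  # 3 zeros of 6
```

Comparing zero_fraction on weights_b, weights_c, and repaired shows directly whether the merge filled anything.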
Step 4: (E)nhance the Solution
Repair plus fine-tune already reaches ~0.175. Here is where the remaining points live, roughly from quick polish to the bigger changes.
Polish the repair
- Recover the layer order with the Hungarian algorithm rather than argmax, so no two layers claim the same slot.
- Fill the weights that were zeroed in both copies instead of leaving them at zero.
- Merge more carefully: average where both copies have a value, instead of always preferring one.
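The first and third repair refinements can be sketched in a few lines. For order recovery, SciPy's linear_sum_assignment solves the same assignment problem as the Hungarian algorithm; with maximize=True it returns the one-to-one pairing with the highest total similarity. For merging, the hypothetical merge helper averages where both copies are non-zero and otherwise keeps whichever is non-zero. The 3×3 similarity matrix is a toy chosen so that argmax collides.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def match_layers(similarity):
    """similarity[i, j] = cosine similarity of B layer i and C layer j; returns a one-to-one match."""
    rows, cols = linear_sum_assignment(similarity, maximize=True)
    match = np.empty(similarity.shape[0], dtype=int)
    match[rows] = cols  # match[i] = C layer assigned to B slot i
    return match


def merge(weight_b, weight_c):
    """Average where both copies have a value; otherwise keep whichever is non-zero."""
    merged = np.where(weight_c != 0, weight_c, weight_b)  # prefer C, fall back to B
    both = (weight_b != 0) & (weight_c != 0)
    merged[both] = (weight_b[both] + weight_c[both]) / 2
    return merged


sim = np.array([[0.9, 0.8, 0.1], [0.95, 0.3, 0.2], [0.1, 0.2, 0.7]])
print(sim.argmax(axis=1))  # [0 0 2]: two B layers claim C layer 0
print(match_layers(sim))  # [1 0 2]: a valid permutation with the best total similarity
print(merge(np.array([1.0, 0.0, 2.0, 0.0]), np.array([3.0, 5.0, 0.0, 0.0])))  # [2. 5. 2. 0.]
```

In the notebook, match = torch.from_numpy(match_layers((fingerprints_b @ fingerprints_c.T).numpy())) would drop in for the argmax line, since the rest of the cell only calls match[layer].item().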
Tune the fine-tuning
- Train a couple of epochs with learning-rate warmup then cosine decay, a higher learning rate on the head than the backbone, and a touch of label smoothing.
- Validate on a split that holds out whole authors, so the score reflects unseen authors the way the real test does.
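Holding out whole authors means splitting the set of author ids, not the rows. A small sketch follows; split_by_author is a hypothetical helper, and scikit-learn's GroupShuffleSplit with groups set to the author column does the same job.

```python
import random


def split_by_author(authors, val_fraction=0.2, seed=0):
    """Return (train_indices, val_indices) so that no author appears on both sides."""
    unique = sorted(set(authors))
    random.Random(seed).shuffle(unique)
    n_val = max(1, int(len(unique) * val_fraction))
    val_authors = set(unique[:n_val])
    train_idx = [i for i, author in enumerate(authors) if author not in val_authors]
    val_idx = [i for i, author in enumerate(authors) if author in val_authors]
    return train_idx, val_idx


authors = ["a", "a", "b", "c", "c", "d", "e"]
train_idx, val_idx = split_by_author(authors, val_fraction=0.4, seed=1)
print(sorted({authors[i] for i in val_idx}))  # the held-out authors
```

In the notebook, passing train_full["target"].tolist() gives row positions for fine-tuning and for evaluate.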
Go beyond the embedding
- Blend the model's embedding with cheap text signals: TF-IDF over characters and words, plus simple style stats like punctuation rate, emoji and link counts, and message length. Authors tend to be consistent in these.
- Re-rank with reciprocal nearest neighbours: boost two messages when each lands in the other's top matches.
- Ensemble encoders: concatenate the embeddings from the fine-tuned model and the raw repaired one before matching.
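A NumPy sketch of the three matching-side ideas: blending two similarity matrices, a simple mutual-top-k boost (much lighter than full k-reciprocal re-ranking), and concatenating two encoders' embeddings. The helper names, the blend weight, k, and the bonus size are assumptions to tune on an author-held-out split, not values from this notebook. After L2-normalizing each part, the cosine similarity of the concatenated vectors equals the average of the two encoders' cosine similarities, so concatenation amounts to an equal-weight blend.

```python
import numpy as np


def normalize(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def blend(sim_a, sim_b, weight=0.5):
    """Weighted mix of two (n, n) similarity matrices, e.g. encoder cosine and TF-IDF cosine."""
    return weight * sim_a + (1 - weight) * sim_b


def reciprocal_rerank(sims, k=5, bonus=0.1):
    """Add a bonus to pairs where each message is in the other's top-k."""
    sims = sims.astype(float).copy()
    np.fill_diagonal(sims, -np.inf)  # a message never retrieves itself
    topk = np.argsort(-sims, axis=1)[:, :k]
    in_topk = np.zeros(sims.shape, dtype=bool)
    in_topk[np.arange(len(sims))[:, None], topk] = True
    mutual = in_topk & in_topk.T  # j in i's top-k AND i in j's top-k
    return sims + bonus * mutual


def concat_embeddings(a, b):
    """Ensemble two encoders: L2-normalize each, concatenate, renormalize."""
    return normalize(np.concatenate([normalize(a), normalize(b)], axis=1))


sims = np.array([
    [1.0, 0.9, 0.1, 0.2],
    [0.9, 1.0, 0.6, 0.1],
    [0.1, 0.6, 1.0, 0.5],
    [0.2, 0.1, 0.5, 1.0],
])
print(reciprocal_rerank(sims, k=1))  # only the mutual pair (0, 1) gets the bonus
```

For the text-signal side, scikit-learn's TfidfVectorizer (for example with analyzer="char_wb") L2-normalizes its rows by default, so X @ X.T on its output is already a cosine-similarity matrix that can go into blend.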
IOAI Philippines 2026 NLP Lecture