Training and evaluating a malware classifier
The training loop is the part everyone copies from a tutorial. The evaluation is the part that decides whether any of it meant anything. After ten epochs our byteplot classifier reaches about 96% accuracy on the data it learned from (96.26% at the final epoch), then turns in 88.54% on data it has never seen. Both numbers deserve careful reading, because on the Malimg dataset a single accuracy figure measures less than it appears to, and in a red-team setting it describes the model under the easiest possible conditions: in-distribution samples, no adversary, nobody trying to break it.
This entry closes out the build. In the previous pieces we took Windows binaries from the Malimg dataset and reshaped their raw bytes into grayscale images, then stood up a convolutional network to sort them into families. Malimg holds 9,339 samples across 25 families, and the distribution is heavily imbalanced: the two Allaple families (Allaple.A and Allaple.L) make up a large share of the data, while several families contribute only around a hundred samples each (Nataraj et al., 2011). That imbalance matters for how we read every result below. The architecture is in place. Now we train it and, more importantly, measure it honestly.
The training loop, and why it looks like this
Training a classifier in PyTorch comes down to repeating four steps over the data: run a forward pass, measure how wrong the output was, push that error backwards through the network, and update the weights. Here is the full loop, with the bookkeeping that lets us watch it learn.
```python
import time

import torch


def train(model, train_loader, n_epochs, verbose=False):
    model.train()  # enable training behaviour (dropout, batch-norm updates)
    criterion = torch.nn.CrossEntropyLoss()  # expects raw logits
    optimizer = torch.optim.Adam(model.parameters())  # default lr=0.001
    training_data = {"accuracy": [], "loss": []}

    for epoch in range(n_epochs):
        running_loss = 0.0
        n_total = 0
        n_correct = 0
        checkpoint = time.time() * 1000  # epoch start, in milliseconds

        for inputs, labels in train_loader:
            optimizer.zero_grad()              # clear gradients from the previous batch
            outputs = model(inputs)            # forward pass: one logit per family
            loss = criterion(outputs, labels)  # how wrong the logits were
            loss.backward()                    # backpropagate to compute gradients
            optimizer.step()                   # apply the weight update

            # Bookkeeping: count correct predictions and accumulate batch loss.
            _, predicted = outputs.max(1)
            n_total += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()
            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)  # mean loss per batch
        epoch_duration = int(time.time() * 1000 - checkpoint)
        epoch_accuracy = compute_accuracy(n_correct, n_total)  # defined further down
        training_data["accuracy"].append(epoch_accuracy)
        training_data["loss"].append(epoch_loss)

        if verbose:
            print(
                f"[i] Epoch {epoch + 1} of {n_epochs}: "
                f"Acc: {epoch_accuracy:.2f}% Loss: {epoch_loss:.4f} "
                f"(Took {epoch_duration} ms)."
            )

    return training_data
```
Two choices in the first three lines carry most of the weight. CrossEntropyLoss is the standard loss for multi-class classification. In PyTorch it combines LogSoftmax and negative log-likelihood loss in one operation, so it expects raw logits as input, not softmax probabilities, and the model’s final layer must not apply softmax itself. Doing so would push the values through softmax twice and distort the gradients. The Adam optimiser maintains a separate, adaptive step size for each parameter using running estimates of the gradient’s first and second moments, with a default base learning rate of 0.001. That makes it a reliable choice for a first pass, because it usually converges without a hand-tuned learning-rate schedule. Neither choice is tuned for this dataset. They are the choices that let you find out quickly whether the architecture works at all, which is what a baseline is for.
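The double-softmax mistake is easy to demonstrate. The following is a minimal sketch, assuming a hypothetical batch whose logits are already confidently correct (a logit of 10.0 on the true class, 0 on the other 24 of 25 classes). CrossEntropyLoss on the logits matches log_softmax followed by nll_loss and is close to zero. Feeding it softmax probabilities instead leaves the loss above 2 even though every prediction is right, so the loss no longer tracks how well the model is doing.

```python
import torch
import torch.nn.functional as F

n_classes = 25  # one logit per Malimg family
labels = torch.tensor([0, 3, 24, 7])

# Hypothetical logits for a model that is confidently *correct*:
# 10.0 on the true class, 0.0 everywhere else.
logits = torch.zeros(len(labels), n_classes)
logits[torch.arange(len(labels)), labels] = 10.0

criterion = torch.nn.CrossEntropyLoss()

# CrossEntropyLoss applied to raw logits, as intended.
ce = criterion(logits, labels)
# The same quantity written out: log-softmax followed by NLL loss.
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)
# The bug: the model already ends in softmax, so softmax is applied twice.
double = criterion(F.softmax(logits, dim=1), labels)

print(f"CE on logits:        {ce.item():.4f}")
print(f"log_softmax + NLL:   {manual.item():.4f}")
print(f"CE on probabilities: {double.item():.4f}")
```

Probabilities lie in [0, 1], so as inputs to a second softmax they produce a nearly flat distribution, and the loss has a floor well above zero.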
Inside the loop, optimizer.zero_grad() is the line people forget. PyTorch accumulates gradients across backward() calls by default, so if you skip the reset, each batch’s gradients add to the last batch’s and the model learns from a mix of stale and current signal. The rest is the canonical sequence. The forward pass produces logits, the loss quantifies the error, loss.backward() computes the gradients through backpropagation, and optimizer.step() applies them. If the mechanics of gradients are hazy, the fundamentals entry in this series covers backpropagation and gradient descent from the ground up.
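Gradient accumulation is visible with a single parameter. The following is a minimal sketch, assuming a toy loss of 3 * w so that each backward() call contributes a gradient of exactly 3. Two passes without a reset leave 6 in w.grad. A pass after optimizer.zero_grad() leaves 3.

```python
import torch

# One scalar parameter and a fixed loss: loss = 3 * w, so d(loss)/dw = 3.
w = torch.tensor(1.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)


def backward_once():
    loss = 3 * w
    loss.backward()  # adds this pass's gradient into w.grad


# Without zero_grad(): the two backward() calls are summed.
backward_once()
backward_once()
accumulated = w.grad.item()

# With zero_grad(): the stale gradient is cleared before the next pass.
optimizer.zero_grad()
backward_once()
fresh = w.grad.item()

print(f"without zero_grad: {accumulated}, with zero_grad: {fresh}")
```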
The one line that separates training from evaluation
The model carries internal state that behaves differently depending on whether it is learning or being judged. Dropout layers zero out a fraction of activations during training but pass everything through at inference. Batch-normalisation layers update running statistics while training, then use those frozen statistics at inference. model.train() and model.eval() toggle this behaviour. Forgetting to switch to eval() is a common reason a model that scored well in training degrades in use, because those layers still behave as if they are mid-training.
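```python
def predict(model, test_data):
    model.eval()  # switch dropout and batch-norm to inference behaviour
    with torch.no_grad():  # no computation graph is needed for inference
        output = model(test_data)
        predicted = output.argmax(dim=1)  # index of the highest logit per sample
    return predicted
```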
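You can check what model.eval() actually switches on a dropout layer. The following is a minimal sketch using a hypothetical two-module model in place of the byteplot CNN. In training mode nn.Dropout(p=0.5) zeroes roughly half the activations and scales the survivors by 1 / (1 - p). In evaluation mode it passes the input through unchanged. Calling train() or eval() on the parent sets the mode on every submodule.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for the classifier: an activation followed by dropout.
model = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p=0.5))
dropout = model[1]
x = torch.ones(1, 1000)

model.train()   # sets .training = True on the module and every submodule
y_train = model(x)

model.eval()    # sets .training = False on the module and every submodule
y_eval = model(x)

zeroed = (y_train == 0).float().mean().item()
survivors = sorted(set(y_train[y_train != 0].tolist()))

# Training mode: about half the values are zeroed, the rest scaled to 1 / (1 - 0.5) = 2.0.
print(f"train mode: {zeroed:.1%} zeroed, surviving values {survivors}")
# Evaluation mode: dropout is the identity, so the output equals the input.
print("eval output equals input:", torch.equal(y_eval, x))
print("dropout submodule in training mode:", dropout.training)
```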
torch.no_grad() is the second half of the discipline. During evaluation we are not learning anything, so building the computation graph needed for gradients only wastes memory and time. Wrapping inference in no_grad() tells PyTorch to skip it. The evaluation function then walks the held-out set, counts correct predictions, and reports a single accuracy figure.
```python
def compute_accuracy(n_correct, n_total):
    # Percentage rounded to two decimal places, e.g. 88.54.
    return round(100 * n_correct / n_total, 2)


def evaluate(model, test_loader):
    model.eval()
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for data, target in test_loader:
            predicted = predict(model, data)  # predict() also sets eval() and no_grad()
            n_total += target.size(0)
            n_correct += (predicted == target).sum().item()
    return compute_accuracy(n_correct, n_total)
```
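The effect of torch.no_grad() shows up on the output tensor. The following is a minimal sketch with a hypothetical linear layer standing in for the model. Outside the context, the output carries a grad_fn that links it to the computation graph. Inside the context, it has none. The values are identical either way.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(16, 25)   # hypothetical stand-in for the classifier
x = torch.randn(8, 16)

out_tracked = model(x)            # graph recorded for a possible backward()
with torch.no_grad():
    out_untracked = model(x)      # no graph recorded

print("tracked:  ", out_tracked.requires_grad, out_tracked.grad_fn)
print("untracked:", out_untracked.requires_grad, out_untracked.grad_fn)
print("same values:", torch.allclose(out_tracked, out_untracked))
```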
Saving uses torch.jit.script, which converts the model to TorchScript, a serialised representation that can be loaded and run without the original Python class definition. That is what makes the saved .pth portable to an inference environment that does not import your training code.
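The script's own save code is not shown here, so the round trip is sketched under stated assumptions. The sketch uses a hypothetical TinyClassifier in place of the byteplot CNN, a temporary file path, and 64×64 single-channel inputs. It relies on torch.jit.script, ScriptModule.save and torch.jit.load. In a real inference environment, the loading half would run in a separate process that never imports TinyClassifier.

```python
import os
import tempfile

import torch


class TinyClassifier(torch.nn.Module):
    """Hypothetical stand-in for the byteplot CNN from the earlier entry."""

    def __init__(self, n_classes: int = 25):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(8, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)  # raw logits, no softmax


model = TinyClassifier().eval()
scripted = torch.jit.script(model)      # compile the module to TorchScript

path = os.path.join(tempfile.mkdtemp(), "model.pth")
scripted.save(path)                     # serialises code and weights together

# Loading side: only torch is needed, not the TinyClassifier definition.
loaded = torch.jit.load(path)
loaded.eval()

x = torch.randn(2, 1, 64, 64)           # hypothetical batch of 64x64 byteplots
with torch.no_grad():
    same = torch.allclose(model(x), loaded(x))
print("outputs match:", same)
```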
Running it
With the helpers in place, the runner loads the data, builds the model, trains for ten epochs, saves the result, and evaluates against the test set.
```console
$ python3 main.py
[i] Epoch 1 of 10: Acc: 57.09% Loss: 1.4741 (Took 41128 ms).
[i] Epoch 2 of 10: Acc: 85.01% Loss: 0.4631 (Took 40630 ms).
[i] Epoch 3 of 10: Acc: 89.60% Loss: 0.2880 (Took 39567 ms).
[i] Epoch 4 of 10: Acc: 91.88% Loss: 0.2294 (Took 39464 ms).
[i] Epoch 5 of 10: Acc: 92.97% Loss: 0.2113 (Took 39367 ms).
[i] Epoch 6 of 10: Acc: 93.86% Loss: 0.1744 (Took 39172 ms).
[i] Epoch 7 of 10: Acc: 95.13% Loss: 0.1572 (Took 39804 ms).
[i] Epoch 8 of 10: Acc: 94.81% Loss: 0.1501 (Took 39092 ms).
[i] Epoch 9 of 10: Acc: 96.51% Loss: 0.1188 (Took 39328 ms).
[i] Epoch 10 of 10: Acc: 96.26% Loss: 0.1198 (Took 39125 ms).
[i] Inference accuracy: 88.54%.
```
The loss falls steadily and the training accuracy climbs from 57% to 85% between the first and second epochs, then flattens around 96% by epoch nine. The rapid early gain is consistent with the property the whole approach is built on, which Nataraj et al. documented in 2011: binaries from the same family produce grayscale images with similar layout and texture, so the large, visually consistent families can be learned within a couple of passes. One thing the script does not give you is a validation curve. It records training accuracy per epoch and a single test figure at the end, so overfitting cannot be observed as it happens. Adding a validation split and plotting training against validation accuracy per epoch reveals the point where the two curves diverge, which is the signal that the model has started fitting the training set rather than the families.
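Once a validation split exists and its accuracy is recorded per epoch alongside training accuracy, the divergence can be located mechanically. The following is a minimal sketch, assuming two per-epoch accuracy lists. The hypothetical summarise_curves helper and its patience-based rule (a common early-stopping heuristic) are not part of the script. The curves in the example are made up for illustration.

```python
def summarise_curves(train_acc, val_acc, patience=2):
    """Per-epoch train/validation gap plus a simple early-stopping verdict.

    best_epoch: 1-indexed epoch with the highest validation accuracy.
    stop_epoch: epoch at which validation accuracy has gone `patience`
    consecutive epochs without beating its best, or None.
    """
    gaps = [round(t - v, 2) for t, v in zip(train_acc, val_acc)]
    best_epoch, best_val, since_best, stop_epoch = 0, float("-inf"), 0, None
    for epoch, val in enumerate(val_acc, start=1):
        if val > best_val:
            best_epoch, best_val, since_best = epoch, val, 0
        else:
            since_best += 1
            if since_best >= patience:
                stop_epoch = epoch
                break
    return {"gaps": gaps, "best_epoch": best_epoch, "stop_epoch": stop_epoch}


# Made-up curves for illustration; these are not outputs of the run above.
train_acc = [60.0, 80.0, 88.0, 92.0, 95.0, 97.0, 98.0]
val_acc = [58.0, 76.0, 83.0, 85.0, 84.0, 84.5, 83.0]
result = summarise_curves(train_acc, val_acc)
print(result)
```

In this toy run validation accuracy peaks at epoch 4 while training accuracy keeps climbing, and the gap widens from 7 to 15 points. That widening is the divergence a training-only log cannot show.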
Reading the gap, and why accuracy is the wrong headline here
The roughly eight-point spread between 96% training accuracy and 88.54% test accuracy is a generalisation gap: the model fits data it has seen better than data it has not. The size of that gap also depends on the random train/test split. The script seeds that split at load time, which makes a run reproducible but does not make the number robust, because a different seed puts different samples in the test set and yields a different figure. A headline metric that shifts with the seed is a fragile metric. If your accuracy depends on which samples landed in the test set, you have measured the luck of the split as much as the model.
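A seeded split is reproducible, but the seed still decides what the test set contains. The following is a minimal sketch, assuming a hypothetical 1,000-sample dataset with a 9:1 class imbalance and an 80/20 split. It uses torch.utils.data.random_split with a seeded torch.Generator and prints how many minority samples land in the test set for a few seeds. How main.py actually splits Malimg may differ.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical imbalanced labels: 900 samples of class 0, 100 of class 1.
labels = torch.cat([torch.zeros(900, dtype=torch.long), torch.ones(100, dtype=torch.long)])
dataset = TensorDataset(torch.arange(len(labels)), labels)


def split(seed, test_fraction=0.2):
    n_test = int(len(dataset) * test_fraction)
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [len(dataset) - n_test, n_test], generator=generator)


def minority_count(subset):
    # Subset.indices maps the subset back to positions in the full dataset.
    return int(labels[subset.indices].eq(1).sum())


# Same seed, same split: the run is reproducible.
_, test_a = split(seed=42)
_, test_b = split(seed=42)
print("same indices:", test_a.indices == test_b.indices)

# Different seeds: the minority class's share of the test set moves around.
for seed in range(5):
    _, test = split(seed)
    print(f"seed {seed}: {minority_count(test)} minority samples in a {len(test)}-sample test set")
```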
The deeper problem is that overall accuracy is a poor quality measure on Malimg specifically, and this is a documented limitation rather than an inference. Because the dataset is dominated by the Allaple families, a model can post a high overall accuracy by classifying the large families correctly while failing on the minority ones, and the single number hides that failure entirely. Published work on Malimg byteplot classification has made the same point directly, noting that accuracy alone is not sufficient to judge model quality on an imbalanced dataset like this one. A frequently observed failure case is confusion between Allaple.A and Allaple.L, the two largest and most visually similar families.
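A toy case makes the failure concrete. The following is a minimal sketch, assuming made-up labels with 95 samples of a dominant family and 5 of a minority family, and a degenerate classifier that always predicts the dominant one. Overall accuracy is 95% while recall on the minority family is zero.

```python
# Made-up labels: 95 samples of a dominant family "A", 5 of a minority family "B".
y_true = ["A"] * 95 + ["B"] * 5
# A degenerate classifier that predicts the dominant family for everything.
y_pred = ["A"] * 100

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def recall(cls):
    # Fraction of true `cls` samples that were predicted as `cls`.
    hits = sum(t == p == cls for t, p in zip(y_true, y_pred))
    return hits / y_true.count(cls)


print(f"accuracy: {accuracy:.0%}")
print(f"recall A: {recall('A'):.0%}, recall B: {recall('B'):.0%}")
```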
The metrics that expose what accuracy hides are standard:
- Per-class precision and recall, which show whether the minority families are being detected at all or quietly absorbed into the dominant classes.
- Macro-averaged F1, which weights every family equally regardless of size, so a model that ignores small families cannot hide behind the Allaple count.
- The confusion matrix, which names exactly which families are being mistaken for which, turning a single number into a map of where the model is weak.
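scikit-learn provides all three metrics (sklearn.metrics.confusion_matrix, classification_report, and f1_score with average="macro"). The following is a minimal plain-Python sketch that shows what they count. The per_class_report helper is hypothetical, and the predictions in the example are made up for illustration.

```python
from collections import Counter


def per_class_report(y_true, y_pred):
    """Confusion matrix, per-class precision/recall/F1 and macro-F1 in plain Python."""
    classes = sorted(set(y_true) | set(y_pred))
    # Count each (true, predicted) pair; missing pairs read as 0.
    confusion = Counter(zip(y_true, y_pred))
    report = {}
    for c in classes:
        tp = confusion[(c, c)]
        fp = sum(confusion[(t, c)] for t in classes if t != c)  # others predicted as c
        fn = sum(confusion[(c, p)] for p in classes if p != c)  # c predicted as others
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[c] = {"precision": precision, "recall": recall, "f1": f1, "support": tp + fn}
    # Macro-F1: unweighted mean over classes, so small families count as much as large ones.
    macro_f1 = sum(r["f1"] for r in report.values()) / len(report)
    # Rows are true classes, columns are predicted classes.
    matrix = [[confusion[(t, p)] for p in classes] for t in classes]
    return classes, matrix, report, macro_f1


# Made-up predictions for three families, for illustration only.
y_true = ["Allaple.A"] * 6 + ["Allaple.L"] * 4 + ["Skintrim.N"] * 2
y_pred = (["Allaple.A"] * 5 + ["Allaple.L"]
          + ["Allaple.A"] * 2 + ["Allaple.L"] * 2
          + ["Allaple.A"] * 2)

classes, matrix, report, macro_f1 = per_class_report(y_true, y_pred)
for name, row in zip(classes, matrix):
    r = report[name]
    print(f"{name:<11} {row}  P={r['precision']:.2f} R={r['recall']:.2f} F1={r['f1']:.2f}")
print(f"macro-F1: {macro_f1:.3f}")
```

In this toy set the matrix shows Skintrim.N is never detected and the two Allaple families are confused in both directions. Overall accuracy alone (7 of 12) shows neither of those failures.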
For a malware classifier this is the difference between a defensible result and a misleading one. A family the model never detects is a family an analyst will never be warned about, and 88.54% overall accuracy says nothing about which families those are.
The number an adversary actually reads
Clean test accuracy measures the model on samples drawn from the same distribution it trained on, with nobody trying to fool it. Put the classifier in front of an adversary and the relevant question changes from how often it is right to how often it can be made wrong while the input stays a working executable.
That constraint is what separates malware evasion from adversarial examples in ordinary image classification. A perturbed photo only has to look like a cat to a human while reading as guacamole to a model; the picture has no other job. A malware sample has to remain a valid PE file that still runs, so a perturbation that flips the prediction but corrupts the binary has produced a broken file, not an evasion. Attackers work within that constraint using functionality-preserving manipulations: appending bytes to the end of the file or writing them into unused regions, padding sections, or editing fields the loader tolerates. Suciu et al. demonstrated byte-append attacks against the byte-based MalConv model, and Demetrio et al. found MalConv relied so heavily on the DOS header that modifying that header alone reached evasion rates above 86%.
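The append manipulation is easy to see at the level of the classifier's input. The following is a minimal sketch, assuming a hypothetical fixed-width byteplot conversion (byteplot_rows, width 16) and toy byte strings in place of a real PE file. It does not produce or test a working executable. Appending leaves every original byte in place, so the original rows are untouched and the classifier simply sees extra rows. This holds exactly here because 64 is a multiple of the width. Otherwise the last partial row would also change. If a pipeline resizes every byteplot to a fixed resolution, the appended rows would instead shift the whole resized image.

```python
def byteplot_rows(data: bytes, width: int = 16):
    """Reshape raw bytes into rows of `width` grayscale pixels (0-255).

    Hypothetical helper: the final row is zero-padded, and the earlier
    entry's conversion may use a different width or resize afterwards.
    """
    padded = data + bytes(-len(data) % width)
    return [list(padded[i:i + width]) for i in range(0, len(padded), width)]


original = bytes(range(64))   # toy stand-in for a binary's bytes
payload = b"\x90" * 32        # hypothetical appended bytes
modified = original + payload

before = byteplot_rows(original)
after = byteplot_rows(modified)

# Every original row survives unchanged...
print("original rows unchanged:", after[:len(before)] == before)
# ...but the input now contains rows the model never saw in training.
print("rows before/after:", len(before), len(after))
```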
Image-based classifiers do not automatically inherit that fragility, and the research record is mixed in an interesting way. One study assessing image-based malware detection against functionality-preserving attacks reported an evasion rate of roughly 5% for the image classifier, against 44% to 54% for MalConv under black-box conditions, because perturbations that work in pixel space tend to break executability when written back to a real binary. That does not make the image-based model safe. It means its robustness is an open question to be measured rather than assumed, which is exactly what the next entries in this arc do: take this trained classifier and apply these attacks to it directly.
The point to carry forward is narrow and verifiable. The 88.54% we produced is an in-distribution figure, reported without a per-class breakdown and without an adversary in the loop. A defensive accuracy number only means something paired with two things: the class distribution it was measured over, and the threat model it was measured under. The number to distrust most is the one reported without either.