
Tiny Weights, Mighty Models: The Mystery of Weight Decay


Why does weight decay induce generalization? Yes, the weights get smaller and Occam’s razor kicks in, but still… why do smaller weights yield better generalization? 🧵

One way I think about this is to view a neural network as a parametrized general program. When you reduce the size of the parameters, you are reducing the length of the program that reproduces the training data.
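One standard way to make “smaller parameters, shorter program” concrete is the minimum description length reading of L2 regularization. If each weight is quantized to a fixed precision and encoded under a zero-mean Gaussian prior, its code length in bits is -log2 of its probability, and the only part that depends on the weight is w²/(2σ² ln 2), which is an L2 penalty. This is a sketch under those assumptions; the function name `code_length_bits` and the values of `sigma` and `precision` are illustrative choices of mine, not something from this post.

```python
import math

def code_length_bits(weights, sigma=1.0, precision=1e-3):
    """Approximate bits to encode each weight, quantized to `precision`,
    under a zero-mean Gaussian prior with std `sigma` (assumed MDL setup)."""
    total = 0.0
    for w in weights:
        # Probability of the weight's quantization bin ~ density * bin width.
        density = math.exp(-w * w / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))
        total += -math.log2(density * precision)
    return total

small = [0.1, -0.2]
large = [1.0, -2.0]
print(code_length_bits(small), code_length_bits(large))

# The gap is exactly the L2 penalty difference, scaled by 1 / (2 * sigma^2 * ln 2).
l2_gap = (sum(w * w for w in large) - sum(w * w for w in small)) / (2 * math.log(2))
print(code_length_bits(large) - code_length_bits(small), l2_gap)
```

Shrinking the weights lowers the bit count, and the difference between the two weight vectors matches the scaled L2 term, which is the sense in which weight decay shortens the description.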

Thus weight decay pushes neural networks toward the shortest program that can regenerate the training data. Program length is also closely tied to complexity: the Kolmogorov complexity of a string is the length of the shortest program that outputs it.
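Here is a toy illustration of weight decay steering training toward a smaller solution, built on assumptions of my own: a linear model with two identical input features, so every w with w[0] + w[1] == 2 fits the training data exactly. Plain gradient descent started at the perfect-but-large solution (3, -1) has zero gradient and stays put. Adding a decay term `weight_decay * w` to each update (for plain gradient descent this is the same as an L2 penalty of (weight_decay / 2) * ||w||^2) slowly pulls it toward the small-norm fit near (1, 1). The data, learning rate and decay strength are arbitrary.

```python
import math

# Toy data (assumed): both input features are identical, so any w with
# w[0] + w[1] == 2 fits the training set exactly.
X = [(1.0, 1.0), (2.0, 2.0), (3.0, 3.0)]
Y = [2.0, 4.0, 6.0]

def mse(w):
    """Mean squared error of the linear model w on the training set."""
    return sum((w[0] * a + w[1] * b - y) ** 2 for (a, b), y in zip(X, Y)) / len(X)

def train(w, lr=0.05, weight_decay=0.0, steps=20_000):
    """Full-batch gradient descent on the MSE, with optional weight decay."""
    w = list(w)
    n = len(X)
    for _ in range(steps):
        grad = [0.0, 0.0]
        for (a, b), y in zip(X, Y):
            r = w[0] * a + w[1] * b - y
            grad[0] += 2 * r * a / n
            grad[1] += 2 * r * b / n
        # Weight decay adds a pull toward zero on top of the data gradient.
        w = [wi - lr * (gi + weight_decay * wi) for wi, gi in zip(w, grad)]
    return w

start = (3.0, -1.0)  # already a perfect fit, but with a large norm
plain = train(start)
decayed = train(start, weight_decay=0.01)

for name, w in [("plain", plain), ("decayed", decayed)]:
    print(name, [round(v, 3) for v in w], "norm", round(math.hypot(*w), 3), "mse", mse(w))
```

The plain run stays at (3, -1) with norm about 3.16; the decayed run ends close to (1, 1), norm about 1.41, while still fitting the training data almost perfectly.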

Note that at this point we are thinking only about reproducing the training set, not the validation/test set.

You can also reproduce the training data by simply “memorizing” it, for example with a program that hard-codes each value. Such a program will be very long, but we can make it shorter by exploiting patterns: imagine replacing the hard-coded values with for loops and other constructs.
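A toy version of this in Python, with made-up data (the first 20 square numbers): one program hard-codes every value, the other replaces them with a for loop. Both regenerate exactly the same training data, but the loop version is much shorter. Source length in characters is only a crude stand-in for program length here.

```python
# Made-up training data: the first 20 square numbers.
data = [i * i for i in range(20)]

# "Memorizing" program: hard-codes every value.
memorized_src = "out = " + repr(data)

# "Compressed" program: a for loop captures the same pattern.
loop_src = "out = []\nfor i in range(20):\n    out.append(i * i)\n"

def run(src):
    """Execute a program's source and return the value it assigns to `out`."""
    env = {}
    exec(src, env)
    return env["out"]

assert run(memorized_src) == run(loop_src) == data
print(len(memorized_src), len(loop_src))
```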

This in effect compresses the long program into a shorter one. Generalization and compression are intimately related.

But why does a smaller program lead to generalization?

There is the usual explanation that nature is frugal or lazy (the principle of least action), but a slightly different way to think about this is to view the universe as a computer. All programs created on this computer arise by coincidence and are therefore small (a program assembled by chance is far more likely to be short than long), but looking at them post hoc, we see them as optimized.
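The idea that programs produced by chance tend to be small can be made concrete with an idealized model in the spirit of Solomonoff's algorithmic probability: if a program is a bit string produced by fair coin flips, the chance of producing any particular program of length L is 2^-L, so each extra bit halves its probability. The simulation below only checks that arithmetic; the two bit strings are arbitrary placeholders rather than real programs, and nothing here is a claim about physics.

```python
import random

def hit_rate(program_bits, trials=200_000, seed=0):
    """Fraction of random coin-flip bit strings that equal `program_bits`."""
    rng = random.Random(seed)
    n = len(program_bits)
    hits = 0
    for _ in range(trials):
        stream = format(rng.getrandbits(n), f"0{n}b")  # n fair coin flips
        hits += stream == program_bits
    return hits / trials

short_prog = "101"         # placeholder 3-bit "program"
long_prog = "1011001101"   # placeholder 10-bit "program"
print(hit_rate(short_prog), 2 ** -len(short_prog))
print(hit_rate(long_prog), 2 ** -len(long_prog))
```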

So for each training set, there exist many programs that can reproduce it. For each of those programs, there exists a validation set on which it is perfect (No Free Lunch theorem!). It just happens that this is not the same validation set created by the program running on the universe-as-computer.
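A tiny illustration of “many programs fit the same training set”: with training inputs 0 to 4 labeled by y = x * x (a made-up “true” rule), a second program that adds 3 * x * (x - 1) * (x - 2) * (x - 3) * (x - 4) agrees on every training point, because that extra term is zero there, yet predicts very different values elsewhere. Each program is perfect on its own validation set; the coefficient 3 is arbitrary.

```python
# Training set generated by a hypothetical "true" rule y = x * x.
train_x = [0, 1, 2, 3, 4]
train_y = [x * x for x in train_x]

def prog_a(x):
    return x * x

def prog_b(x):
    # The extra term vanishes at every training input, so the fit stays perfect.
    return x * x + 3 * x * (x - 1) * (x - 2) * (x - 3) * (x - 4)

assert [prog_a(x) for x in train_x] == train_y
assert [prog_b(x) for x in train_x] == train_y

# Off the training set they disagree: each defines a different "perfect" validation set.
val_x = [5, 6, 7]
print([prog_a(x) for x in val_x])
print([prog_b(x) for x in val_x])
```

At x = 5, for example, `prog_a` gives 25 while `prog_b` gives 385.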

As the program becomes smaller, the validation set on which it is perfect might match more and more closely the one generated by the original program running on the computer that is our universe.
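Putting the pieces together in a deliberately rigged toy: the “universe” program is assumed to be the short rule x * x, and among candidate programs that all fit the training set we pick the one with the shortest source. Source length is again a crude proxy for description length, and because the universe's program is short by construction, this only shows the mechanism described above, not evidence that it holds in general. The candidate list and helper names are hypothetical.

```python
# Hypothetical candidate programs, written as Python expressions in x.
candidates = [
    "{0: 0, 1: 1, 2: 4, 3: 9, 4: 16}.get(x)",                 # memorized lookup table
    "x * x + 3 * x * (x - 1) * (x - 2) * (x - 3) * (x - 4)",  # fits training, differs elsewhere
    "x * x",                                                  # short rule
]

def universe(x):
    # Assumed "original program" that generates training and validation data.
    return x * x

train_x = [0, 1, 2, 3, 4]
val_x = [5, 6, 7, 8, 9]

def compile_prog(src):
    return eval("lambda x: " + src)

def fits(src, xs):
    f = compile_prog(src)
    return all(f(x) == universe(x) for x in xs)

consistent = [s for s in candidates if fits(s, train_x)]
shortest = min(consistent, key=len)

for s in consistent:
    print(len(s), fits(s, val_x), s)
print("chosen:", shortest)
```

The shortest consistent candidate, `x * x`, is also the one that matches the held-out points; the lookup table and the polynomial with the extra term both fit the training data but miss the validation set.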

When that happens, we say that “generalization” has occurred.
