https://transformer-circuits.pub/2023/monosemantic-features
- In DL, a single neuron often handles multiple unrelated features/concepts (polysemanticity)
- “Superposition”: the model represents more features than it has neurons, by storing them as non-orthogonal directions in activation space
- Normed vector space: a vector space equipped with a norm, i.e. a way to measure the magnitude of each vector as a real number
- Banach space: complete normed vector space (complete here = every Cauchy sequence converges in the space)
- Complete system of vectors: a subset/system of vectors of a Banach space is complete if every element of the space can be approximated arbitrarily well in norm by finite linear combinations of vectors from the system (a different sense of “complete” than above)
- Overcomplete: the system contains more vectors than necessary, i.e. it stays complete even with 1 or more vectors removed
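- A quick formalization of these two definitions (my notation, not the paper's):

```latex
% Complete (total) system {v_i} in a normed space X: every x in X is
% approximated arbitrarily well in norm by finite linear combinations
\forall x \in X,\ \forall \varepsilon > 0,\ \exists\, c_1,\dots,c_k :\quad
\Big\| x - \sum_{j=1}^{k} c_j \, v_{i_j} \Big\| < \varepsilon
% Overcomplete system: {v_i} remains complete after removing at least one vector
% (e.g. any n+1 vectors that span an n-dimensional space)
```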
- Anthropic uses a weak dictionary learning algorithm, a sparse autoencoder (SAE), to find an overcomplete feature basis in a model exhibiting superposition
- 1-layer transformer with a 512-neuron ReLU MLP layer
- MLP = multi-layer perceptron, i.e. the transformer's feed-forward (FFNN) block
- The SAE is trained on MLP activations collected from 8B data points
- The SAE decomposes the neuron activations into features, and each feature is interpreted by compiling what kinds of prompts activate it; a feature is a direction in activation space, so one feature generally involves many neurons (and one neuron participates in many features)
- Features are extracted by the SAE, which has 1 hidden layer of much higher dimension than the MLP activation space
- There exists a complete set of features that can act as the basis for the whole activation space
- Treat the activation of each datapoint (input) as a linear combination of ‘features’
- Output of the encoder (weights + ReLU) = feature activations; the columns of the decoder weight matrix = feature directions
- Features are the HIDDEN LAYER ACTIVATIONS, overcomplete as there are more hidden-layer units than neurons (inputs)
- MSE loss with an L1 penalty on the feature activations for sparsity (see the sketch below)
- Scale of training (amount of data) mattered a lot: more data = more distinguishable features
- The decoder weight matrix reconstructs neuron activations as linear combos of FEATURES !!!!!
- Sparsity enforces that any given activation vector is explained by a small number of features
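- A minimal PyTorch sketch of this setup; the dimensions (512 MLP neurons, 4096 features), variable names, and the 1e-3 L1 coefficient are illustrative assumptions, and the paper's actual implementation (bias handling, decoder-column normalization, training schedule) differs:

```python
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    """Overcomplete autoencoder on MLP activations (sketch, not the paper's exact code)."""
    def __init__(self, n_neurons: int = 512, n_features: int = 4096):
        super().__init__()
        # Encoder: maps 512-dim MLP activations to 4096 feature activations
        self.encoder = nn.Linear(n_neurons, n_features)
        # Decoder: each column of its weight matrix is a feature direction in activation space
        self.decoder = nn.Linear(n_features, n_neurons, bias=False)

    def forward(self, x: torch.Tensor):
        f = torch.relu(self.encoder(x))   # feature activations (sparse, non-negative)
        x_hat = self.decoder(f)           # reconstruction = linear combo of feature directions
        return x_hat, f

def sae_loss(x: torch.Tensor, x_hat: torch.Tensor, f: torch.Tensor, l1_coeff: float = 1e-3):
    mse = (x - x_hat).pow(2).mean()       # reconstruction error
    l1 = f.abs().sum(dim=-1).mean()       # L1 penalty -> few active features per input
    return mse + l1_coeff * l1

# Usage on a batch of stored MLP activations:
# sae = SparseAutoencoder()
# x_hat, f = sae(mlp_activations)            # mlp_activations: [batch, 512]
# loss = sae_loss(mlp_activations, x_hat, f)
```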
- SAE background (autoencoders in general): an unsupervised NN trained with the target values set equal to the inputs
- N inputs, N outputs, some hidden layers
- Learns an approximation to the identity function, but if we place limits such as a max number of hidden neurons, it is forced to learn a compressed representation and reconstruct the input from it
- If the inputs are random this is hard, but the network will exploit any pattern/structure in the data
- Sparsity constraint: make hidden neurons inactive most of the time (ReLU outputs 0 for negative pre-activations) by approximately constraining the average activation to a sparsity parameter close to 0
- Achieved by adding a penalty term to the cost function for significant deviations from that parameter; many penalty formulations work (one common KL-based form is shown below)
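- For reference, one classic way to write such a penalty (the KL-based sparsity term from standard sparse-autoencoder tutorials; the Anthropic SAE uses the simpler L1 penalty noted above instead):

```latex
J_{\text{sparse}}
  = J_{\text{reconstruction}}
  + \beta \sum_{j=1}^{n_{\text{hidden}}} \mathrm{KL}\!\left(\rho \,\|\, \hat{\rho}_j\right),
\qquad
\mathrm{KL}(\rho \,\|\, \hat{\rho}_j)
  = \rho \log\frac{\rho}{\hat{\rho}_j}
  + (1-\rho)\log\frac{1-\rho}{1-\hat{\rho}_j}
```

  where \hat{\rho}_j is the average activation of hidden unit j over the training data and \rho is the target sparsity (close to 0).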
- We decompose into more features than there are neurons
- KL divergence: the expected number of extra bits needed to encode samples from one distribution using a code optimized for another; closely related to cross-entropy loss (cross-entropy = entropy + KL divergence)
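- In symbols, for discrete distributions P and Q:

```latex
D_{\mathrm{KL}}(P \,\|\, Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)},
\qquad
H(P, Q) = H(P) + D_{\mathrm{KL}}(P \,\|\, Q)
```

  So minimizing cross-entropy H(P, Q) over Q with P fixed is equivalent to minimizing D_KL(P || Q).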