I am going to start with a summary, because this post is mostly a big dump of tables of neural network output values. There are 3 main things I show here.
- I found that when a model begins to classify target images wrongly, or “forget” them, the output value of the target class goes down at the same time as the values for other classes go up. The model seems to think that the image could be one of several classes.
- The targeted pruning starts to affect the behaviour of the model before the test accuracy drops significantly. This means we might need to remove far fewer parameters than originally thought.
- The outputs for the main data, which is not targeted by the pruning, are largely unaffected by it. This suggests that the model is functioning normally there.
Before I show too many tables of values, I should explain some things. I use a linear-quadratic final layer: a linear final layer, trained with a quadratic cost function. This means that during training the outputs of the final layer are pushed towards 1 for the correct class and 0 for the incorrect ones. With a softmax final layer, the model is pushed towards these values only after the softmax function has been applied. I have found that the outputs before the softmax can be about +12 for a correct class and -4 for an incorrect one. This is just in the models I have looked at; it is not a rule. The important thing to note is that a softmax is normalised so that its outputs always sum to 1. The outputs of a linear layer can sum to 1, but nothing in the layer forces them to. The reasons I started doing this are not really important here; you just need to understand what we are looking at. If this were a softmax model, we would expect the outputs to look different.
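A minimal sketch of the difference, in plain Python rather than whatever framework the model was trained in. `softmax` and `squared_error` are names chosen here for illustration, the +12/-4 scores are just the rough values mentioned above, and `linear_out` is a made-up vector. The softmax sums to 1 by construction, while the linear outputs can sum to anything; the quadratic cost only pulls them towards a one-hot target during training.

```python
import math

def softmax(scores):
    """Map raw scores to positive values that always sum to exactly 1."""
    m = max(scores)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def squared_error(outputs, correct_class):
    """Quadratic cost against a one-hot target (1 for the correct class, 0 elsewhere)."""
    target = [1.0 if i == correct_class else 0.0 for i in range(len(outputs))]
    return sum((o - t) ** 2 for o, t in zip(outputs, target))

# Softmax model: rough pre-softmax scores of +12 for the correct class and -4 elsewhere.
logits = [-4.0] * 10
logits[7] = 12.0
probs = softmax(logits)
print(f"softmax sum = {sum(probs):.6f}, class 7 = {probs[7]:.6f}")

# Linear model: a made-up output vector. Nothing forces it to sum to 1;
# the cost only penalises its distance from the one-hot target.
linear_out = [0.2] * 10
print(f"linear sum = {sum(linear_out):.3f}, cost = {squared_error(linear_out, 7):.3f}")
```

Swapping in any other vector for `linear_out` shows the same thing: the softmax result always sums to 1, whereas the linear outputs only do so if the model has learned to produce them that way.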
I was not really looking for anything in particular here; I just wanted to see if the outputs looked like sensible outputs. I concluded pretty quickly that this is not the output suppression I discussed in previous blogs. The model seems to be behaving normally, not just on the main dataset but also on the target class.
Here are a couple of examples of what happens to the output values for a targeted class. The first is a 7, and the table shows what happens to its outputs as parameters are removed. The left column is the number of parameters removed from each of the 2 middle layers (so double this for the total), and the other columns are the output values for each class.
Removed | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
---|---|---|---|---|---|---|---|---|---|---
Zero | -0.000 | 0.002 | 0.003 | 0.001 | 0.001 | 0.001 | 0.002 | 0.988 | 0.003 | 0.000
500 | 0.194 | -0.041 | -0.054 | 0.251 | -0.073 | -0.098 | -0.065 | 0.456 | 0.036 | 0.443
1000 | 0.236 | -0.056 | 0.001 | 0.362 | -0.08 | -0.106 | -0.096 | 0.341 | 0.04 | 0.433
2000 | 0.204 | -0.106 | 0.086 | 0.379 | -0.091 | -0.099 | -0.076 | 0.179 | 0.123 | 0.384
The first row, where no parameters are removed, is typical of any datapoint: the value for 7 is close to 1, and the others are close to 0. As more parameters are removed, the value for 7 falls and the values for 0, 3, 8, and 9 start to increase.
Here is another, this time an 8. This one is interesting because many more parameters were removed, and it was still not “forgotten”: the value for 8 is still the highest output in every row.
Removed | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
---|---|---|---|---|---|---|---|---|---|---
Zero | -0.002 | 0.000 | 0.000 | -0.002 | 0.000 | -0.004 | -0.000 | 0.000 | 1.01 | -0.002
500 | 0.025 | -0.06 | -0.015 | 0.068 | -0.016 | -0.026 | 0.033 | -0.085 | 0.861 | 0.249
1000 | 0.094 | -0.064 | -0.052 | 0.085 | 0.07 | 0.015 | 0.011 | -0.134 | 0.773 | 0.245
2000 | 0.093 | -0.113 | -0.081 | 0.068 | 0.065 | 0.131 | 0.06 | -0.183 | 0.693 | 0.209
5000 | 0.174 | -0.122 | -0.011 | 0.023 | 0.05 | 0.13 | 0.035 | -0.159 | 0.642 | 0.219
15000 | 0.154 | -0.078 | -0.018 | -0.014 | 0.245 | 0.092 | 0.062 | -0.111 | 0.436 | 0.303
The interesting thing in both of these examples is that when the value for the targeted class goes down, the values for the other classes go up. This is not a normalised vector like a softmax output; there is no rule saying that when you reduce one value you need to increase another. This is happening internally in the model. The model seems to be confused by the datapoint and thinks it could be one of several classes. The outputs also always seem to sum to about 1, which led me to the next thread.
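A quick way to check the “sums to about 1” observation is to add up each row of the 7 table and see which class wins. This is a small stdlib-only sketch; the values are copied straight from the table above, and `summarise` is just a helper named here for illustration.

```python
# Output rows for the targeted 7, copied from the first table above.
# Keys are the number of parameters removed from each of the 2 middle layers.
rows_for_7 = {
    0:    [-0.000, 0.002, 0.003, 0.001, 0.001, 0.001, 0.002, 0.988, 0.003, 0.000],
    500:  [0.194, -0.041, -0.054, 0.251, -0.073, -0.098, -0.065, 0.456, 0.036, 0.443],
    1000: [0.236, -0.056, 0.001, 0.362, -0.08, -0.106, -0.096, 0.341, 0.04, 0.433],
    2000: [0.204, -0.106, 0.086, 0.379, -0.091, -0.099, -0.076, 0.179, 0.123, 0.384],
}

def summarise(outputs):
    """Return (sum of all class outputs, predicted class) for one datapoint."""
    predicted = max(range(len(outputs)), key=lambda c: outputs[c])
    return sum(outputs), predicted

for removed, outputs in rows_for_7.items():
    total, predicted = summarise(outputs)
    print(f"{removed:>5} removed: sum={total:.3f}, predicted class={predicted}")
```

On these rows the sums stay between roughly 0.98 and 1.08, and the predicted class switches from 7 to 9 somewhere between 500 and 1000 parameters removed per layer.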
Below is a table of the sums of the output values as parameters are removed. Each sum is the mean over 100 datapoints at that step. I also looked at the datapoints individually, and the mean does reflect the general trend.
Parameters removed | Zero | 100 | 500 | 1000 | 5000 | 10000 | 15000
---|---|---|---|---|---|---|---
Target Sum | 1.001 | 0.952 | 1.036 | 1.054 | 1.093 | 1.069 | 1.15
Target Accuracy | 98.24% | 96.78% | 84.14% | 68.67% | 18.57% | 7.89% | 4.37%
Main Sum | 1.0012 | 1.001 | 1.009 | 1.01 | 1.023 | 1.039 | 1.069
Main Accuracy | 98.59% | 98.65% | 98.61% | 98.59% | 98.55% | 98.45% | 98.22%
The sums slowly increase, a bit more for the target class than for the main data. I have also included the accuracy here. I find the slight increase less interesting than how consistently close to 1 the sums stay, even when the model is only classifying about 4% of the target datapoints correctly.
Here I am going to show the other end of the spectrum: what happens when very few parameters are removed. Again I am looking at 7s. The table shows the target class output for 10 datapoints (the values that should be close to 1), alongside the target accuracy. For all of these steps, the accuracy on the main data remains unchanged.
Parameters removed | Target accuracy | Output 1 | Output 2 | Output 3 | Output 4 | Output 5 | Output 6 | Output 7 | Output 8 | Output 9 | Output 10
---|---|---|---|---|---|---|---|---|---|---|---
Zero | 0.9824 | 0.988 | 1.021 | 0.95 | 1.002 | 1.062 | 1.012 | 1.076 | 0.97 | 1.03 | 1.029
10 | 0.9776 | 0.938 | 0.966 | 0.917 | 0.992 | 1.045 | 0.962 | 0.999 | 0.914 | 0.982 | 0.95
20 | 0.9766 | 0.935 | 0.955 | 0.917 | 0.987 | 1.03 | 0.96 | 0.99 | 0.911 | 0.976 | 0.947
50 | 0.9747 | 0.851 | 0.871 | 0.858 | 0.945 | 0.952 | 0.885 | 0.915 | 0.849 | 0.894 | 0.846
100 | 0.9684 | 0.753 | 0.767 | 0.818 | 0.915 | 0.879 | 0.821 | 0.863 | 0.797 | 0.807 | 0.825
200 | 0.9581 | 0.6 | 0.625 | 0.741 | 0.836 | 0.777 | 0.721 | 0.76 | 0.711 | 0.65 | 0.704
300 | 0.9406 | 0.557 | 0.594 | 0.683 | 0.811 | 0.73 | 0.666 | 0.743 | 0.685 | 0.612 | 0.644
400 | 0.9202 | 0.501 | 0.545 | 0.659 | 0.747 | 0.664 | 0.595 | 0.757 | 0.624 | 0.568 | 0.611
500 | 0.8715 | 0.46 | 0.508 | 0.649 | 0.693 | 0.634 | 0.56 | 0.74 | 0.579 | 0.532 | 0.591
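After removing only 50 parameters per layer, the target accuracy has barely moved (98.24% to 97.47%), but the target class outputs have already fallen from around 1 to between about 0.85 and 0.95. Pruning very few parameters has a significant effect on the behaviour of the model, well before it shows up in the accuracy.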
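To put a rough number on this, the sketch below compares the relative drop in the mean target output with the drop in target accuracy, using a few rows copied from the table above. Keep in mind that the mean is over only these 10 datapoints while the accuracy is over the whole target test set, so the comparison is approximate; `mean` and `steps` are names chosen for this sketch.

```python
def mean(values):
    return sum(values) / len(values)

# Target class outputs for the same 10 datapoints (all 7s), copied from the table above.
# Keys: parameters removed per middle layer. Values: (target accuracy, outputs).
steps = {
    0:   (0.9824, [0.988, 1.021, 0.95, 1.002, 1.062, 1.012, 1.076, 0.97, 1.03, 1.029]),
    50:  (0.9747, [0.851, 0.871, 0.858, 0.945, 0.952, 0.885, 0.915, 0.849, 0.894, 0.846]),
    100: (0.9684, [0.753, 0.767, 0.818, 0.915, 0.879, 0.821, 0.863, 0.797, 0.807, 0.825]),
    500: (0.8715, [0.46, 0.508, 0.649, 0.693, 0.634, 0.56, 0.74, 0.579, 0.532, 0.591]),
}

baseline_acc, baseline_out = steps[0]
for removed, (acc, outputs) in steps.items():
    output_drop = 1 - mean(outputs) / mean(baseline_out)  # relative drop in mean output
    accuracy_drop = (baseline_acc - acc) * 100  # drop in percentage points
    print(f"{removed:>3} removed: mean output {mean(outputs):.3f} ({output_drop:.1%} lower), "
          f"accuracy down {accuracy_drop:.2f} percentage points")
```

On these rows, removing 50 parameters per layer lowers the mean target output by about 12.6% while the accuracy moves by under 1 percentage point; at 500 the mean output is down about 41%.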
Well, I already have lots of tables in here, so a couple more can’t hurt. These are the outputs for 5 datapoints from the main data, before the model was pruned and after 15,000 parameters had been removed from each middle layer.
Zero parameters removed:
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
---|---|---|---|---|---|---|---|---|---
-0.005 | 0.001 | 1.027 | -0.007 | -0.003 | -0.005 | -0.003 | -0.001 | -0.001 | -0.000
-0.001 | 0.998 | -0.002 | -0.001 | 0.002 | -0.001 | -0.000 | 0.002 | 0.002 | -0.000
1.026 | -0.001 | -0.003 | -0.005 | -0.002 | -0.004 | -0.002 | -0.001 | -0.004 | -0.002
0.001 | 0.004 | 0.004 | 0.003 | 0.972 | 0.003 | 0.003 | 0.006 | 0.004 | 0.004
-0.001 | 1.001 | -0.001 | 0.001 | 0.000 | -0.002 | 0.000 | 0.001 | 0.000 | 0.000
15,000 parameters removed:
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
---|---|---|---|---|---|---|---|---|---
0.011 | 0.003 | 1.009 | 0.006 | 0.036 | -0.002 | 0.002 | -0.002 | -0.006 | 0.018
-0.01 | 0.98 | 0.012 | 0.005 | -0.002 | -0.003 | 0.01 | 0.012 | -0.006 | 0.033
0.972 | 0.044 | 0.015 | -0.015 | 0.034 | 0.001 | 0.013 | 0.004 | -0.012 | 0.034
0.021 | 0.044 | -0.005 | 0.033 | 0.81 | 0.044 | 0.049 | 0.03 | 0.014 | 0.015
0.003 | 0.945 | 0.014 | 0.012 | -0.000 | 0.015 | 0.013 | 0.026 | -0.001 | 0.042
These are the same 5 datapoints. I was going to try to be clever, pick a bunch and ask you to guess which is which, suggesting that you can’t tell the difference, but there is a difference. The main data test accuracy has dropped at this point from 98.6% to 98.4%. The 4th row has changed the most: the output for its correct class, 4, drops from 0.972 to 0.81, but it is still classified correctly. Overall, this is a remarkably small difference.
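Here is a similar check for the main data, again a stdlib-only sketch with the values copied from the two tables above. The true labels are not given in the tables, so the code assumes they match the unpruned predictions (2, 1, 0, 4, 1); `argmax` is just a small helper defined here.

```python
# The 5 main-data datapoints, copied from the tables above:
# before pruning, and after 15,000 parameters were removed.
before = [
    [-0.005, 0.001, 1.027, -0.007, -0.003, -0.005, -0.003, -0.001, -0.001, -0.000],
    [-0.001, 0.998, -0.002, -0.001, 0.002, -0.001, -0.000, 0.002, 0.002, -0.000],
    [1.026, -0.001, -0.003, -0.005, -0.002, -0.004, -0.002, -0.001, -0.004, -0.002],
    [0.001, 0.004, 0.004, 0.003, 0.972, 0.003, 0.003, 0.006, 0.004, 0.004],
    [-0.001, 1.001, -0.001, 0.001, 0.000, -0.002, 0.000, 0.001, 0.000, 0.000],
]
after = [
    [0.011, 0.003, 1.009, 0.006, 0.036, -0.002, 0.002, -0.002, -0.006, 0.018],
    [-0.01, 0.98, 0.012, 0.005, -0.002, -0.003, 0.01, 0.012, -0.006, 0.033],
    [0.972, 0.044, 0.015, -0.015, 0.034, 0.001, 0.013, 0.004, -0.012, 0.034],
    [0.021, 0.044, -0.005, 0.033, 0.81, 0.044, 0.049, 0.03, 0.014, 0.015],
    [0.003, 0.945, 0.014, 0.012, -0.000, 0.015, 0.013, 0.026, -0.001, 0.042],
]

def argmax(xs):
    """Index of the largest output, i.e. the predicted class."""
    return max(range(len(xs)), key=lambda i: xs[i])

# Assumption: the unpruned model classifies all 5 datapoints correctly.
labels = [argmax(row) for row in before]

for i, (b, a) in enumerate(zip(before, after), start=1):
    c = labels[i - 1]
    print(f"row {i}: class {c}, output {b[c]:.3f} -> {a[c]:.3f}, "
          f"sum {sum(a):.3f}, still correct: {argmax(a) == c}")

mean_sum_after = sum(sum(row) for row in after) / len(after)
print(f"mean sum after pruning: {mean_sum_after:.3f}")
```

Every row is still classified correctly, and the post-pruning sums average about 1.06, close to the Main Sum values in the earlier table.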
This blog took me a bit longer than I was expecting. I thought I would have a look and hopefully confirm that the model was behaving in a reasonable way. After thinking about it a bit, I realised that this sort of analysis will probably be a useful tool when deploying this method. If I were working on a more complicated model and needed to target a lot of different behaviours, it is encouraging to know that I probably don’t need to remove the equivalent of 30,000 parameters each time.
When I talk about the model being confused, or what it thinks, this is just useful language; I am not intending to anthropomorphise the model. There is probably a limit to how much we can reason from the outputs. In the example of the 8 above, the first other class value to rise noticeably is 9. The model thinks this could be a 9, which makes sense to me, since both digits are drawn with a small circle in the top half of the image. By the last row, the value for 4 is also quite high, and the value for 7 is negative. This might not mean anything; we could just be dealing with a slightly broken model that is behaving unpredictably. If the outputs start to make no sense, there is no rule that says they have to make sense. I am amazed that they SEEM to make sense as much as they do.
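The “first class to rise” reading for the 8 can be checked mechanically by ranking the non-target outputs at each step. `ranked_alternatives` is a hypothetical helper named for this sketch, and, as said above, the ranking only describes the outputs; it does not tell us why the model produces them.

```python
# Outputs for the targeted 8, copied from the second table above,
# keyed by parameters removed from each middle layer.
rows_for_8 = {
    500:   [0.025, -0.06, -0.015, 0.068, -0.016, -0.026, 0.033, -0.085, 0.861, 0.249],
    1000:  [0.094, -0.064, -0.052, 0.085, 0.07, 0.015, 0.011, -0.134, 0.773, 0.245],
    2000:  [0.093, -0.113, -0.081, 0.068, 0.065, 0.131, 0.06, -0.183, 0.693, 0.209],
    5000:  [0.174, -0.122, -0.011, 0.023, 0.05, 0.13, 0.035, -0.159, 0.642, 0.219],
    15000: [0.154, -0.078, -0.018, -0.014, 0.245, 0.092, 0.062, -0.111, 0.436, 0.303],
}
TARGET = 8

def ranked_alternatives(outputs, target, k=2):
    """Return the k non-target classes with the highest outputs, best first."""
    others = [c for c in range(len(outputs)) if c != target]
    return sorted(others, key=lambda c: outputs[c], reverse=True)[:k]

for removed, outputs in rows_for_8.items():
    print(f"{removed:>5} removed: top alternatives {ranked_alternatives(outputs, TARGET)}")
```

In every row the strongest alternative is 9, and by 15,000 the runner-up behind it is 4, matching the reading above; 8 itself stays the highest output throughout.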
The model being broken or unpredictable, or having high target class accuracy even though the outputs are being affected, is not the end of the story. Whether we are removing a lot of parameters or very few, this method would be one step in a fine-tuning process. In my experience, performance comes back with very little retraining.
Next
Suppressing outputs. I mentioned this in earlier blogs, and it doesn’t seem to be the problem, so I should really discuss whether I was off the mark with it.
Note on the formatting:
I am making these blogs in WordPress, and I could not get the tables to render in a readable way using markdown. My workaround is to export HTML from notebooks, and this is the result. Not a web developer, sorry.