Introduction to TabNet
Info
Author: Who is Void, published on December 18, 2021, reading time: approximately 6 minutes, WeChat official account article link:
1 Introduction
On tabular data, tree models such as LightGBM and XGBoost often perform well. Possible reasons are:
- Rich feature libraries are easy to construct, or already exist.
- The decision boundaries of tree models are axis-aligned hyperplanes (the feature space is cut one threshold at a time), which suits this kind of problem.
On some tasks an NN model may perform only passably on its own, yet we still want NN models to take part in the final model ensemble. Researchers have therefore designed NN models that behave like tree models. This article introduces one such model: TabNet.
2 Constructing decision trees with NNs
Decision trees are probably familiar. Their decision boundaries can be seen in the following simple example:

The two features x1 and x2 are split by the thresholds a and d respectively, which divides the feature space into four regions. So how can an NN simulate this process?

The model takes the same two features, x1 and x2, as input. Each is first isolated by a Mask layer, mirroring how a tree model independently picks the single feature with the largest split gain at each step of its construction.
Each feature is then fed to a fully connected layer with hand-designed weights and biases, and the output passes through a ReLU activation.
ReLU(x) equals x when x > 0 and 0 otherwise. So for x1, the output is [c1x1-c1a, 0, 0, 0] when x1 > a and [0, -c1x1+c1a, 0, 0] when x1 < a, which is equivalent to splitting at the threshold a. The two -1 entries are only there to pad and align the dimensions; ReLU turns them into 0.
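As a minimal sketch of the thresholding step above, the following pure-Python snippet hard-codes the branch for x1. The weight vector `[c1, -c1, 0, 0]`, the bias `[-c1*a, c1*a, -1, -1]`, and the values `c1 = 2.0`, `a = 1.0` are assumptions chosen to reproduce the outputs described in the text; `threshold_branch` is a hypothetical helper name, not part of any library.

```python
def relu(values):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, v) for v in values]


def threshold_branch(x, weights, bias):
    """One masked feature -> FC layer with hand-set weights -> ReLU."""
    return relu([w * x + b for w, b in zip(weights, bias)])


# Assumed constants, for illustration only.
c1, a = 2.0, 1.0
w_x1 = [c1, -c1, 0.0, 0.0]
b_x1 = [-c1 * a, c1 * a, -1.0, -1.0]  # the two -1 entries only pad the vector

above = threshold_branch(3.0, w_x1, b_x1)  # x1 > a: only the first slot is active
below = threshold_branch(0.5, w_x1, b_x1)  # x1 < a: only the second slot is active
print(above, below)
```

Which slot is nonzero tells you on which side of the threshold the input falls, and the padded slots always come out as 0.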
Finally, the outputs for the different features are added and passed through softmax, giving an output vector such as [0.1, 0.5, 0.3, 0.1]. Each dimension is the weight with which one condition influences the final decision; for example, 0.1 means that x1 > a contributes only 10% to the final decision. Note that the model parameters are still updated by backpropagation; no split gain is computed.
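The add-then-softmax step can be sketched in a few lines of pure Python. The two branch outputs below are made-up numbers, and the slot order `[x1>a, x1<a, x2>d, x2<d]` is an assumption for illustration, not something fixed by the text.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    shift = max(values)
    exps = [math.exp(v - shift) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical post-ReLU outputs of the two feature branches.
out_x1 = [4.0, 0.0, 0.0, 0.0]  # x1 > a
out_x2 = [0.0, 0.0, 0.0, 1.0]  # x2 < d

summed = [p + q for p, q in zip(out_x1, out_x2)]
condition_weights = softmax(summed)
print(condition_weights)
```

The resulting weights are all positive and sum to 1, and the condition with the largest pre-softmax value receives the largest weight.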
3 TabNet structure
TabNet improves on the simple structure above.

Features first pass through a BN (batch normalization) layer and then through the Feature transformer layer. This layer plays a role similar to the fully connected layer described earlier, and its structure is as follows:

It has two parts: one shared across all steps and one that each decision step has independently. The small blocks inside are built from a fully connected (FC) layer, a BN layer, and a GLU (Gated Linear Unit) that gates what passes through, joined by residual connections. The output is then split: one part feeds the current step's decision output, and the other goes to the attentive transformer, which performs feature selection, mainly by using sparsemax to set the weights of some features to exactly 0. Intuitively, features already selected in earlier steps should be less likely to be chosen again. The result is the Mask matrix used in the next step.
The feature attributes output in the lower right corner of the figure represents the global importance of the variables.
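To make the sparsemax step concrete, here is a minimal pure-Python sketch of sparsemax, together with a prior-scale update that discourages reusing a feature. The update rule `prior * (gamma - mask)`, the parameter name `gamma`, and the helper names `next_mask` and `update_prior` are assumptions based on my reading of the TabNet paper, not a reference implementation. The attention scores are made-up numbers.

```python
def sparsemax(scores):
    """Project scores onto the probability simplex; small entries become exactly 0."""
    ordered = sorted(scores, reverse=True)
    running_sum = 0.0
    tau = 0.0
    for k, value in enumerate(ordered, start=1):
        running_sum += value
        # The support condition holds for a prefix of the sorted scores,
        # so the last k that satisfies it fixes the threshold tau.
        if 1.0 + k * value > running_sum:
            tau = (running_sum - 1.0) / k
    return [max(s - tau, 0.0) for s in scores]


def next_mask(scores, prior):
    """Scale the attention scores by the prior, then apply sparsemax."""
    return sparsemax([p * s for p, s in zip(prior, scores)])


def update_prior(prior, mask, gamma=1.0):
    """Shrink the prior of features that the current mask already used."""
    return [p * (gamma - m) for p, m in zip(prior, mask)]


scores = [2.0, 1.0, 0.1]  # hypothetical attention scores for three features
prior = [1.0, 1.0, 1.0]   # no feature has been used yet

mask_step1 = next_mask(scores, prior)
prior = update_prior(prior, mask_step1)
mask_step2 = next_mask(scores, prior)
print(mask_step1, mask_step2)
```

In this toy run the first step puts all of its weight on the first feature, so with `gamma = 1.0` that feature's prior drops to 0 and the second step has to distribute its mask over the remaining features.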
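One simple way to turn per-step masks into a global importance score is to sum them over steps and samples and then normalize. The sketch below assumes every step counts equally, which is a simplification; the function name `global_importance` and the mask values are made up for illustration.

```python
def global_importance(masks_per_step):
    """masks_per_step[step][sample][feature] -> normalized importance per feature.

    Assumes at least one mask entry is nonzero, which holds for sparsemax
    masks because each one sums to 1.
    """
    n_features = len(masks_per_step[0][0])
    totals = [0.0] * n_features
    for step_masks in masks_per_step:
        for sample_mask in step_masks:
            for j, weight in enumerate(sample_mask):
                totals[j] += weight
    grand_total = sum(totals)
    return [t / grand_total for t in totals]


# Hypothetical masks: 2 decision steps, 2 samples, 3 features.
masks = [
    [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0]],    # step 1
    [[0.0, 0.95, 0.05], [0.0, 0.0, 1.0]],  # step 2
]
importance = global_importance(masks)
print(importance)
```

The result is one non-negative number per feature, and the numbers sum to 1, so they can be read as shares of the model's total attention.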
4 Conclusion
The above is the basic structure of TabNet. It achieves instance-wise feature selection through an additive model and an attention mechanism, combining the advantages of tree models and NNs. When you run into a tabular-data problem, it is worth trying to see how well it works.
If the theory behind TabNet is still unclear, you can also read the original paper or the articles written by experts on Zhihu.