Developing Explainable AI using inherently explainable machine learning models

Capgemini
5 Aug 2021

Explainable AI is machine learning with a focus on how we can understand a model's output. This article looks at TabNet, a neural network built with explainability in mind.

Explainable AI (XAI) is traditional machine learning (ML) with a focus on model explainability: dissecting a model's predictions, coefficients and weights, and understanding how and why the model arrives at its outputs. Many state-of-the-art models are black boxes, meaning we don't know exactly what is going on inside them and can't explain their decisions in a way that humans can understand and digest.

This article focuses on tabular data; other types of data, such as images, graphs or natural language, may require different methods. We only look at one way of developing explainability: using inherently explainable models. There are many other methods, the most popular of which is SHAP (Lundberg & Lee (2017)), which uses game theory to explain a model after it has been trained.
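
For comparison, post-hoc methods such as SHAP are applied after a model has been trained. The snippet below is a minimal, illustrative sketch using the open-source shap library with a gradient boosting model; the dataset, model choice and variable names are assumptions made for illustration rather than details from this article.

# Minimal post-hoc explainability sketch using SHAP (illustrative setup only).
import shap
import xgboost
from sklearn.datasets import load_breast_cancer

# Train a black-box gradient boosting model on a small tabular dataset.
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
model = xgboost.XGBClassifier().fit(X, y)

# TreeExplainer computes Shapley values (a game-theoretic attribution) that
# quantify how much each feature contributed to each individual prediction.
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

# Global summary: feature contributions aggregated across the whole dataset.
shap.summary_plot(shap_values, X)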

Inherently explainable models

ML research is a rapidly changing field, and although large, complex models get the most press, there has been a recent surge in XAI-related research. Between 2017 and 2019, 175 papers were published on the topic, three times more than in the previous period (Vilone & Longo (2020)). As adoption of ML has become more widespread, the need for explainability has grown.

We can define inherent explainability through three criteria, proposed by Lipton (2017):

  • Simulatability: Can a human walk through the model’s steps in a reasonable amount of time and generate an output? For this article, a reasonable amount of time is a working day.
  • Algorithmic Transparency: Do we understand how the model will behave on unseen data?
  • Decomposability: Is every aspect of the model explainable, including parameters, calculations and engineered features?

Another issue is local explainability. How can we understand why the model made a single prediction? This can be very difficult, which is why methods such as SHAP are so popular.

Explainability and accuracy are often considered a trade-off; however, in many cases, inherently explainable models can perform comparably to black-box models (Ning et al. (2020)). It is important to choose a level of explainability that suits the problem and to define which explainability criteria are essential for a viable solution.

The most widely used explainable models are decision trees and logistic regression; however, there are many lesser-known options. This article looks at TabNet (Arik & Pfister (2019)), which was developed and released by Google in 2019. Traditionally, neural networks have not matched the performance of gradient boosting models, such as XGBoost, when applied to tabular data. TabNet aims to outperform these methods while maintaining some explainability.

TabNet

Figure One: TabNet architecture diagram. The red arrows are included to keep overlapping arrows distinct.

The core concepts of TabNet are as follows:

  • The model takes raw tabular data as input without any pre-processing (Figure One, feature transformer) and is trained using gradient descent-based optimisation.
  • TabNet uses sequential attention (Figure One, attentive transformer) to choose features at each decision step, enabling interpretability and better learning as the learning capacity is used for the most useful features.
  • Feature selection is instance-wise, i.e. it can be different for each row of the training dataset.
  • TabNet employs a single deep learning architecture for both feature selection and reasoning; this is known as soft feature selection.
  • The mask (Figure One, mask) enables interpretability by covering up features for each prediction. We can extract the details of the mask and visualise it to understand which features are covered up the most (see the sketches after this list and after Table One).
  • We can select the number of decision steps; increasing it increases the number of trainable parameters in the model.
  • Each step gets an equal vote in the final prediction of the model, mimicking ensemble methods.
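
To make these concepts concrete, the sketch below trains a TabNet model with the open-source pytorch-tabnet package (a community PyTorch implementation of the paper); the dataset, parameter values and variable names are illustrative assumptions rather than details taken from this article.

# Illustrative TabNet training sketch using the pytorch-tabnet package (assumed setup).
from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

# n_steps sets the number of sequential decision steps; more steps means
# more trainable parameters, and each step gets an equal vote in the prediction.
clf = TabNetClassifier(n_steps=3)

# Raw tabular features go in directly and the model is trained with
# gradient descent-based optimisation.
clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], max_epochs=50)

# Global view: feature importances aggregated from the attention masks.
print(clf.feature_importances_)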

Table One summarises how well TabNet meets the key explainability criteria.

Metric | Score | Reason
Simulatability | Low | The mechanisms the model uses to process features are too complex to be recreated by a human in a reasonable amount of time.
Decomposability | Low | We typically expect neural networks to generate complex representations of features and allow interactions between features. Depending on how many steps are selected, there could be a lot of weights and coefficients that can't readily be interpreted.
Transparency | Medium | We can get an idea of how the model will react to unseen data as we know which features the model uses.
Local Explainability | High | By visualising the masks for a single prediction, we can establish which features the model used to make its predictions.

Table One: TabNet explainability against criteria.
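
Building on the assumed pytorch-tabnet setup in the earlier sketch, the snippet below shows one way to inspect the masks behind individual predictions; the plotting choices are illustrative.

# Local explainability sketch: visualise the masks for individual predictions
# (continues the assumed setup from the earlier example).
import matplotlib.pyplot as plt

# explain() returns per-sample feature attributions plus the raw mask
# produced at each decision step.
explain_matrix, masks = clf.explain(X_valid)

# Plot the mask of the first decision step for a handful of samples:
# rows are samples, columns are features; brighter cells indicate
# features that step attended to (i.e. did not cover up).
plt.imshow(masks[0][:20], aspect="auto")
plt.xlabel("Feature index")
plt.ylabel("Sample")
plt.title("TabNet step 0 feature mask")
plt.show()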

Conclusion

The field of explainable AI is constantly developing, and there is already a wide range of models available to suit a variety of needs. This article has only scratched the surface of what is available. Developing models with explainability in mind doesn't always have to be a trade-off; many explainable models can still perform as well as black-box models while maintaining transparency.

Many projects require transparency, particularly those that use ML to make decisions about people, but there are benefits for any project: a better understanding of model bias, a clearer path to production as a result of stakeholder understanding, and the use of ML as a tool to assist decisions.