Explainable AI (XAI) is machine learning (ML) with an added focus on model explainability. This includes dissecting a model’s predictions, coefficients and weights, and understanding how and why the model makes its predictions. Many state-of-the-art models are black boxes, meaning we don’t know exactly what is going on inside them and can’t explain their decisions in a way that humans can understand and digest.
This article focuses on tabular data; other data types, such as images, graphs or natural language, may require different methods. We only look at one way of developing explainability: using inherently explainable models. There are many other methods. The most popular is SHAP (Lundberg & Lee (2017)), which uses game theory to explain a model after it has been trained.
Inherently explainable models
ML research is a rapidly changing field, and although large, complex models get the most press, there has been a recent surge in XAI-related research. Between 2017 and 2019, 175 papers were published on the topic, three times more than in the previous period (Vilone & Longo (2020)). As adoption of ML has become more widespread, the need for explainability has grown.
We can define inherent explainability through three criteria, proposed by Lipton (2017):
- Simulatability: Can a human walk through the model’s steps in a reasonable amount of time and generate an output? For this article, a reasonable amount of time is a working day.
- Algorithmic Transparency: Do we understand how the model will behave on unseen data?
- Decomposability: Is every aspect of the model explainable, including parameters, calculations and engineered features?
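Logistic regression is a useful contrast to TabNet for simulatability and decomposability. A person can compute its prediction by hand, and each feature's contribution can be read off on its own. The sketch below is plain Python. The feature names, intercept and coefficients are hypothetical values chosen for illustration, not a fitted model.

```python
import math

# Hypothetical coefficients for a toy loan-approval model (illustration only, not fitted to data).
INTERCEPT = -1.5
COEFFICIENTS = {"income_k": 0.03, "late_payments": -0.8, "years_at_address": 0.1}


def explain_logistic(row):
    """Return the predicted probability and each feature's additive contribution to the log-odds."""
    contributions = {name: coef * row[name] for name, coef in COEFFICIENTS.items()}
    log_odds = INTERCEPT + sum(contributions.values())
    probability = 1.0 / (1.0 + math.exp(-log_odds))  # logistic (sigmoid) function
    return probability, contributions


prob, parts = explain_logistic({"income_k": 50, "late_payments": 1, "years_at_address": 3})
for name, value in parts.items():
    print(f"{name:>18}: {value:+.2f} log-odds")
print(f"probability: {prob:.3f}")  # -1.5 + 1.5 - 0.8 + 0.3 = -0.5 log-odds, about 0.378
```

Every step is a multiplication, a sum and one sigmoid, so a person can check the prediction by hand (simulatability), and each term in `parts` can be attributed to a single input (decomposability).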
Another consideration is local explainability: how can we understand why the model made a single prediction? This can be very difficult, which is why post-hoc methods such as SHAP are so popular.
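SHAP builds on Shapley values from cooperative game theory. When a model has only a few features, the Shapley values for a single prediction can be computed exactly by enumerating every subset of features. The block below is a brute-force sketch of that underlying quantity. It is not the `shap` library. It assumes that "absent" features are replaced by a baseline value, and `toy_model` and its inputs are hypothetical.

```python
from itertools import combinations
from math import factorial


def shapley_values(model, x, baseline):
    """Exact Shapley values for one prediction, by brute-force enumeration of feature subsets.

    Features outside a subset are replaced by their baseline value (one common convention).
    Cost grows as 2**n, so this is only practical for a handful of features.
    """
    n = len(x)

    def value(subset):
        # Features in `subset` keep their real values; the others take the baseline value.
        return model([x[i] if i in subset else baseline[i] for i in range(n)])

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for a coalition S that excludes feature i
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


def toy_model(z):
    # Hypothetical model: a linear term plus an interaction between features 1 and 2.
    return 2 * z[0] + z[1] * z[2]


x, baseline = [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]
phi = shapley_values(toy_model, x, baseline)
print([round(v, 6) for v in phi])  # [2.0, 3.0, 3.0]: the interaction is split equally
print(sum(phi), toy_model(x) - toy_model(baseline))  # contributions sum to f(x) - f(baseline)
```

The contributions always sum to the gap between the prediction and the baseline prediction, which is what makes this kind of explanation additive. Exact enumeration scales exponentially with the number of features, so practical implementations rely on approximations or model-specific shortcuts.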
Explainability and accuracy are often considered a trade-off. However, in many cases inherently explainable models can perform comparably to black box models (Ning et al. (2020)). It is important to choose a level of explainability that suits the problem and to define which explainability criteria are essential for a viable solution.
The most widely used explainable models are decision trees and logistic regression, but there are many lesser-known options. This article looks at TabNet (Arik & Pfister (2019)), which was developed and released by Google in 2019. Traditionally, neural networks have not matched the performance of gradient boosting models such as XGBoost on tabular data. TabNet aimed to outperform these methods while maintaining some explainability.
The core concepts of TabNet are as follows:
- The model input is raw tabular data without any pre-processing (Figure One, feature transformer) and is trained using gradient descent-based optimisation.
- TabNet uses sequential attention (Figure One, attentive transformer) to choose which features to use at each decision step. This enables interpretability and more efficient learning, because the model's learning capacity is spent on the most useful features.
- Feature selection is instance-wise, i.e. it can be different for each row of the training dataset.
- TabNet uses a single deep learning architecture for both feature selection and reasoning; this is known as soft feature selection.
- The mask (Figure One, mask) enables interpretability by covering up features for each prediction. We can extract the mask values and visualise them to see which features the model selects and which it covers up.
- We can choose the number of decision steps; adding steps increases the number of trainable parameters in the model.
- Each step gets an equal vote in the final prediction of the model, mimicking ensemble methods.
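The sequential attention, masking and step-summing ideas above can be sketched in a few lines of plain Python. This is a simplified, single-row illustration that loosely follows the TabNet paper as we understand it. The per-step attention scores and decision weights are hypothetical inputs, not learned layers. The mask is computed with sparsemax, which TabNet uses so that masks can contain exact zeros. `gamma` stands in for the paper's relaxation parameter. The final linear layer of the real model is omitted.

```python
def sparsemax(z):
    """Like softmax, maps scores to weights that sum to 1, but can assign exactly zero weight."""
    z_sorted = sorted(z, reverse=True)
    cumulative, k_max, cumsum_at_k = 0.0, 0, 0.0
    for k, value in enumerate(z_sorted, start=1):
        cumulative += value
        if 1 + k * value > cumulative:  # the k-th largest score is still in the support
            k_max, cumsum_at_k = k, cumulative
    tau = (cumsum_at_k - 1) / k_max  # threshold chosen so the outputs sum to 1
    return [max(v - tau, 0.0) for v in z]


def tabnet_masks(step_scores, gamma=1.3):
    """Sequential attention for a single row (simplified sketch).

    step_scores[i] stands in for the attentive transformer's learned scores at step i;
    here they are hypothetical inputs. With gamma = 1 a feature is pushed towards one step only.
    """
    prior = [1.0] * len(step_scores[0])  # prior scales: every feature starts fully available
    masks = []
    for scores in step_scores:
        mask = sparsemax([p * s for p, s in zip(prior, scores)])
        masks.append(mask)
        # Down-weight features this step relied on, so later steps look elsewhere
        prior = [p * (gamma - m) for p, m in zip(prior, mask)]
    return masks


def combine_steps(masks, row, step_weights):
    """Simplified decision: each step's ReLU output is added to a running total (no final layer)."""
    total = 0.0
    for mask, weights in zip(masks, step_weights):
        masked_row = [m * x for m, x in zip(mask, row)]  # instance-wise feature selection
        decision = sum(w * x for w, x in zip(weights, masked_row))  # stand-in for the feature transformer
        total += max(decision, 0.0)  # ReLU, then summed across steps with equal weight
    return total


row = [3.0, 1.0, 2.0]
masks = tabnet_masks([[2.0, 1.0, 0.1], [2.0, 1.5, 0.1]], gamma=1.0)
print(masks)  # [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
print(combine_steps(masks, row, [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]))  # 3.0 + 1.0 = 4.0
```

With `gamma=1.0`, the feature chosen at step 1 gets a prior scale of zero, so in this example step 2 selects a different feature. Larger `gamma` values relax this constraint and let features be reused across steps.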
Table One summarises how well TabNet meets the key explainability criteria.
|Criterion|Rating|Rationale|
|---|---|---|
|Simulatability|Low|The mechanisms the model uses to process features are too complex to be recreated by a human in a reasonable amount of time.|
|Decomposability|Low|We typically expect neural networks to generate complex representations of features and allow interactions between features. Depending on how many steps are selected, there could be a lot of weights and coefficients that can’t readily be interpreted.|
|Transparency|Medium|We can get an idea of how the model will react to unseen data, as we know which features the model uses.|
|Local Explainability|High|By visualising the masks for a single prediction, we can establish which features the model used to make its prediction.|
Table One: TabNet explainability against criteria.
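The TabNet paper combines a row's step masks into a single aggregate mask, weighting each step by how much its decision output contributes. This gives the per-prediction view behind the local explainability rating above. Below is a minimal sketch in plain Python. It assumes the step masks and decision vectors have already been extracted from a trained model, and the values used are hypothetical.

```python
def aggregate_mask(masks, step_decisions):
    """Per-row feature importance from TabNet-style step masks (simplified sketch).

    Each step's mask is weighted by eta[i] = sum(ReLU(d[i])), the size of that step's
    decision output, and the weighted masks are normalised so they sum to 1.
    """
    n_features = len(masks[0])
    eta = [sum(max(v, 0.0) for v in d) for d in step_decisions]
    raw = [sum(e * m[j] for e, m in zip(eta, masks)) for j in range(n_features)]
    total = sum(raw)
    return [r / total for r in raw] if total > 0 else [0.0] * n_features


# Hypothetical masks and decision vectors for one row (two steps, three features).
masks = [[1.0, 0.0, 0.0], [0.0, 0.6, 0.4]]
decisions = [[3.0, -1.0], [1.0, 0.0]]
importance = aggregate_mask(masks, decisions)
print([round(v, 3) for v in importance])  # [0.75, 0.15, 0.1]
```

Steps whose decision outputs are larger count for more, so a feature selected by an influential step dominates the explanation for that row.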
The field of explainable AI is constantly developing, and there is already a wide range of models available to suit a variety of needs. This article has only scratched the surface of what is available. Developing models with explainability in mind doesn’t always have to be a trade-off; many explainable models can perform as well as black box models while maintaining transparency.
Many projects require transparency, particularly those that use ML to make decisions about people. Explainability also benefits other projects: it builds a better understanding of model bias, gives a clearer path to production through stakeholder understanding, and supports the use of ML as a tool to assist decisions.
Adam Shafi – Data Scientist
Adam Shafi is a data scientist in the Insights & Data practice, with five years of analytics and machine learning experience in the music industry and public sector. Get in touch directly at https://www.linkedin.com/in/adamshafi/