Welcome to Astrapia’s documentation!
Astrapia is a Python framework for comparing and evaluating post-hoc explainers for tabular data. Explainers help users understand the behaviour of modern, opaque machine learning models. Still, some explainers are better than others. This framework is designed to compare them using a generalised set of metrics.
Astrapia is not a framework for ranking explainers. It merely helps users judge the advantages and disadvantages of different explainers.
Quickstart
See the Installation section to learn how to install Astrapia.
Currently, we offer two examples: the UCI adult dataset and the UCI breast cancer dataset. These examples can be found under `notebooks/AstrapiaComparatorDemo.ipynb`. Here we show you how to use Astrapia to compare different explainers using the UCI adult dataset. First, navigate into data/adult/ and run:
$ python setup_adult.py
Files for the datasets will be generated under the corresponding folder. Next, import the dependencies:
import astrapia as xb
from astrapia import explainers, dataset
from astrapia.comparator import ExplainerComparator
from astrapia.visualize_metrics import print_metrics, load_metrics_from_json
import sklearn.ensemble
Now load the dataset:
data = dataset.load_csv_data('adult', root_path='../data')
Then, train a machine learning classifier that you want to explain.
rf = sklearn.ensemble.RandomForestClassifier(n_estimators=50, n_jobs=5)
rf.fit(xb.utils.onehot_encode(data.data, data), data.target.to_numpy().reshape(-1))
pred_fn = lambda x: rf.predict_proba(xb.utils.onehot_encode(x, data))
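The lambda above encodes its input before calling the classifier, so it accepts rows in the dataset's original format and hides the one-hot layout from whoever calls it. The sketch below illustrates that idea in plain Python. `fit_categories`, `onehot_encode` and `make_pred_fn` are hypothetical helpers written for this illustration, not Astrapia's `xb.utils.onehot_encode`, and the "model" is a stand-in function rather than a trained classifier.

```python
def fit_categories(rows, categorical_columns):
    """Collect the sorted set of values seen for each categorical column."""
    return {
        col: sorted({row[col] for row in rows})
        for col in categorical_columns
    }


def onehot_encode(rows, categories):
    """Encode rows into fixed-width numeric vectors.

    The column layout depends only on `categories`, so a single row is
    encoded exactly like the training rows were.
    """
    encoded = []
    for row in rows:
        vector = []
        for col in sorted(row):
            if col in categories:
                # One slot per known value; an unseen value yields all zeros.
                vector.extend(
                    1.0 if row[col] == value else 0.0
                    for value in categories[col]
                )
            else:
                vector.append(float(row[col]))
        encoded.append(vector)
    return encoded


def make_pred_fn(model, categories):
    """Return a function that takes raw rows and returns model outputs."""
    return lambda rows: model(onehot_encode(rows, categories))


train_rows = [
    {"age": 39, "workclass": "State-gov"},
    {"age": 50, "workclass": "Private"},
    {"age": 38, "workclass": "Private"},
]
categories = fit_categories(train_rows, ["workclass"])

# Stand-in "model": reports the width of each encoded vector.
widths = make_pred_fn(lambda X: [len(v) for v in X], categories)
```

Because the category list is fixed up front, encoding one row or many rows produces vectors of the same width, which is presumably why the quickstart passes `data` to the encoder on every call.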
Prepare post-hoc explainers that you want to compare. Here we chose LIME and Anchors.
ex_lime = explainers.LimeExplainer(data, pred_fn, discretize_continuous=False)
ex_anchors = explainers.AnchorsExplainer(data, pred_fn, 0.9)
Astrapia offers a convenient interface for comparing explainers. Instantiate an ExplainerComparator and add each explainer to it under a label of your choice:
comp = ExplainerComparator()
comp.add_explainer(ex_anchors, 'ANCHORS 0.9')
comp.add_explainer(ex_lime, 'LIME')
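The label passed as the second argument is the name by which an explainer is selected later, for example in `print_metrics(metric_data, explainer='LIME')`. As a rough mental model only, a comparator can be thought of as a registry that maps labels to explainers and runs all of them on the same instances. The class below is a hypothetical toy written for this illustration; it is not how `ExplainerComparator` is implemented, and the two "explainers" are dummy functions.

```python
class MiniComparator:
    """Toy stand-in for a comparator: a label -> explainer registry."""

    def __init__(self):
        self.explainers = {}
        self.results = {}

    def add_explainer(self, explainer, name):
        # Labels must be unique, since results are looked up by label.
        if name in self.explainers:
            raise ValueError(f"explainer name already used: {name!r}")
        self.explainers[name] = explainer

    def explain_instances(self, instances):
        # Every registered explainer sees exactly the same instances.
        for name, explainer in self.explainers.items():
            self.results[name] = [explainer(instance) for instance in instances]
        return self.results


# Two dummy "explainers": callables that map an instance to an explanation.
def largest_feature(instance):
    return max(instance, key=instance.get)


def smallest_feature(instance):
    return min(instance, key=instance.get)


mini = MiniComparator()
mini.add_explainer(largest_feature, "largest")
mini.add_explainer(smallest_feature, "smallest")
results = mini.explain_instances(
    [{"age": 39, "hours": 40}, {"age": 50, "hours": 13}]
)
```

Running every explainer on identical instances is what makes the per-explainer results comparable afterwards.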
Choose an instance or multiple instances to explain:
comp.explain_instances(data.data.iloc[[0]]) # single instance
or
comp.explain_instances(data.data.iloc[[111, 222, 333, 444]]) # multiple instances
Store the metric data as JSON and assert that storing and reloading the data does not modify it:
metric_data = comp.get_metric_data()
comp.store_metrics()
assert load_metrics_from_json('metrics.json') == metric_data
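An equality check like the one above only holds when the data consists of types that JSON can represent exactly. The following standalone sketch shows both cases using only the standard library. The metric structure and values are made up for illustration, and the real layout of Astrapia's metric data may differ.

```python
import json
import os
import tempfile

# Hypothetical metric structure, used only to demonstrate the round trip.
metric_data = {
    "LIME": {"runtime": [0.12, 0.15], "area": 0.5},
    "ANCHORS 0.9": {"runtime": [1.8, 2.1], "area": 0.25},
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "metrics.json")
    with open(path, "w") as fh:
        json.dump(metric_data, fh)
    with open(path) as fh:
        reloaded = json.load(fh)

# Dicts with string keys, lists, strings and floats survive unchanged.
assert reloaded == metric_data

# Tuples and non-string keys do not: JSON only has arrays and string keys.
lossy = json.loads(json.dumps({1: (0.1, 0.2)}))
assert lossy == {"1": [0.1, 0.2]}
```

If a custom metric returns tuples or uses integer keys, converting them to lists and strings before storing keeps the round-trip assertion meaningful.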
To visualize metrics as tables or bar charts:
# show all explainers
print_metrics(metric_data, plot='table', show_metric_with_one_value=True)
print_metrics(metric_data, plot='bar', show_metric_with_one_value=False)
# show single explainer result
print_metrics(metric_data, explainer='ANCHORS 0.9')
print_metrics(metric_data, plot="bar", explainer='LIME')