可解释机器学习的SHAP分析

作者: pdnbplus | 发布时间: 2024/10/25 | 阅读量: 306

可解释机器学习的SHAP分析 -- 潘登同学的机器学习笔记

使用 Shapley 值的可解释 AI 简介

  • 本文首先介绍了 shapley 值的概念,通过一个LoL比赛的例子,拆解了 shapley 值的计算方法,并介绍了其中的数学方法。
  • 本文将对使用 Shapley 值解释机器学习模型的介绍, 主要举例讲解了Shapley用于各种机器学习算法的解释方法。

Shapley 值的数学原理

参考链接: https://www.zhihu.com/question/23180647/answer/3221458936

Shapley Value,这是合作博弈论中最重要、最基础的概念之一。Shapley value的提出主要是为了公平地衡量合作过程中每个人的实际贡献。例如一次LoL比赛中,每方有三位英雄出场,且胜利目标为最高的杀人数。其中,射手的杀人能力最强,肉盾可以保护队友,奶妈可以给队友回血。在这个简单的博弈中,我们往往发现结果是射手+肉盾+奶妈的组合,能够打败三个射手的组合。但是如果没有队友之间的配合,射手应该有最强的击杀能力,因此三个射手才应该胜利。这就是合作博弈的精髓:整体大于局部之和。那么问题来了:现在如果针对三个人的贡献发工资(回报),谁收的更多?

  • Shapley Value认为,需要考虑合作带来的边际贡献。例如肉盾本身杀人能力不行,但是和射手组合后,团队的杀人能力猛增,这个时候需要赋予肉盾很大的边际贡献,因而发到更高的工资。

合作博弈(Cooperative Game)中,一个合作项目 C=(N,v)C=(N,v) 由 n 个agent(又称玩家,player) N=1,2,...,nn2N={1,2,...,n|n≥2}共同完成,其中,每些agent组成的子集 S 为这个项目做出的价值(value)为 v(S)v(S)

假设射手 a1a_1 的场均击杀数为8,坦克 a2a_2 的场均击杀数为4,奶妈 a3a_3 的场均击杀数为2,有:
v({1})=8,v({2})=4,v({3})=2v ( \left\{ 1 \right\} ) = 8 , v ( \left\{ 2 \right\} ) = 4 , v ( \left\{ 3 \right\} ) = 2

假设总场均击杀数为合作项目。射手+坦克:坦克可以帮射手抵挡伤害,射手会有更持久的发挥,二人合作的场均击杀数到达20。射手+奶妈:奶妈比较脆,直接跟射手可能被秒,且射手需要保奶妈,合作效率不高,场均击杀数仍为8(纯假设,有些奶妈能暴杀)。坦克+奶妈:持久战法,不用回城,场均击杀数达到12。

v({1,2})=20,v({1,3})=8,v({2,3})=12v ( \left\{ 1 , 2 \right\} ) = 2 0 , v ( \left\{ 1 , 3 \right\} ) = 8 , v ( \left\{ 2 , 3 \right\} ) = 1 2

三个放在一起,爆发全能,场均击杀数达到30,有
v({1,2,3})=30v ( \left\{ 1 , 2 , 3 \right\} ) = 30

要计算Shapley Value,需要先计算边际贡献,这个边际贡献往往不等于新成员独立的贡献。例如,射手+坦克联盟加入新成员奶妈后,场均击杀数提升了10,这就是边际贡献。而奶妈的独立贡献仅为2。

正式地,我们定义边际贡献Δi(S)\Delta i ( S )为:
Δi(S)=v(S{i})v(S)\Delta i ( S ) = v ( S \cup \left\{ i \right\} ) - v ( S )

由此获得边际贡献表:

P Order Δ(a1)\Delta (a_1) Δ(a2)\Delta (a_2) Δ(a3)\Delta (a_3)
1/6 123 8 v({1,2})v({1})=12v ( \left\{ 1 , 2 \right\} ) - v ( \left\{ 1 \right\} )=12 v({1,2,3})v({1,2})=10v ( \left\{ 1 , 2 , 3 \right\} ) - v ( \left\{ 1 , 2 \right\} )=10
1/6 132 8 v({1,2,3})v({1,3})=22v ( \left\{ 1 , 2, 3 \right\} ) - v ( \left\{ 1, 3 \right\} )=22 v({1,3})v({1})=0v ( \left\{ 1 , 3 \right\} ) - v ( \left\{ 1 \right\} )=0
1/6 213 v({1,2})v({2})=16v ( \left\{ 1 , 2 \right\} ) - v ( \left\{ 2 \right\} )=16 4 v({1,2,3})v({1,2})=10v ( \left\{ 1 , 2 , 3 \right\} ) - v ( \left\{ 1 , 2 \right\} )=10
1/6 231 v({1,2,3})v({2,3})=18v ( \left\{ 1 , 2,3 \right\} ) - v ( \left\{ 2,3 \right\} )=18 4 v({2,3})v({2})=8v ( \left\{ 2 , 3 \right\} ) - v ( \left\{ 2 \right\} )=8
1/6 312 v({1,3})v({3})=6v ( \left\{ 1 , 3 \right\} ) - v ( \left\{ 3 \right\} )=6 v({1,2,3})v({1.3})=22v ( \left\{ 1, 2 , 3 \right\} ) - v ( \left\{ 1.3 \right\} )=22 2
1/6 321 v({1,2,3})v({2,3})=18v ( \left\{ 1 ,2, 3 \right\} ) - v ( \left\{ 2,3 \right\} )=18 v({2,3})v({3})=10v ( \left\{ 2 , 3 \right\} ) - v ( \left\{ 3 \right\} )=10 2

在这里就可以很轻松地引入Shapley Value了。Shapley Value定义每个agent得到的回报。正式地,aia_i的shapley value为它的边际贡献之和除以 N 的最大子集数:

φi(v)=1N!j=1N!Δi(Sj)\varphi _ { i } ( v ) = \frac { 1 } { | N | ! } \sum _ { j = 1 } ^ { | N | ! } \Delta i ( S _ { j } )

如果从上表出发, i的shapley值就是 Delta(ai)Delta (a_i) 列的值之和。例如对于 i=1i=1

φ1(v)=1N!j=1N!Δ(a1)(Sj)=16(8+8+16+18+6+18)=746\varphi _ { 1 } ( v ) = \frac { 1 } { | N | ! } \sum _ { j = 1 } ^ { | N | ! } \Delta (a_1) ( S _ { j } ) = \frac { 1 } { 6 } ( 8 + 8 + 1 6 + 1 8 + 6 + 1 8 ) = \frac { 7 4 } { 6 }
同理可得
φ2=746,φ3=326\varphi _ { 2 } = \frac { 7 4 } { 6 } , \varphi _ { 3 } = \frac { 3 2 } { 6 }

得每个职业获得的回报(即角色的重要性)

w1=41.1%,w2=41.1%,w3=17.8%w _ { 1 } = 4 1 . 1 \% , w _ { 2 } = 4 1 . 1 \% , w _ { 3 } = 1 7 . 8 \%

结论:肉盾与射手一样重要,且都比奶妈重要很多(一倍以上)。

公平的回报分配机制应该具有什么性质

  • 有效性 Efficiency

所有玩家的Shapley Value之和等于总合作的价值,因此所有回报都分配给了玩家,无抽成。

iNψi(v)=ν(N)\sum _ { i \in N } \psi _ { i } ( v ) = \nu ( N )

  • 对称性 Symmetry

如果两个玩家与所有其他玩家集组成的集合都有相同的价值,则称这两个玩家为可交换的。正式地,两个可交换的玩家满足

v(S{i})=v(S{j})ifS,i,jSv ( S \cup \left\{ i \right\} ) = v ( S \cup \left\{ j \right\} ) \quad \text{if} \quad \forall S , i , j \notin S

当两个玩家满足交换性时,它们应该获得相同的回报。因此其shapley值满足

ψi(N,v)=ψj(N,v)ifij\psi _ { i } ( N , v ) = \psi _ { j } ( N , v ) \quad \text{if} \quad i \Leftrightarrow j

  • 伪玩家性质(Dummy Player)

shapley值的伪玩家性质规定,一个伪玩家对任何合作的贡献均为零。正式地,一个伪玩家满足

S,v(S{i})=v(S)\forall S , v ( S \cup \left\{ i \right\} ) = v ( S )
因为伪玩家没有任何贡献,所以它也不应该得到任何回报。

ψi(N,v)=0if i is a dummy player\psi _ { i } ( N , v ) = 0 \quad \text{if i is a dummy player}

  • 可加性 Additivity

又名Linearality(线性)。如果一个博弈可以被分为两个部分 v=v1+v2v=v_1+v_2 ,则我们可以用加法分解它们的回报。正式地,如果存在博弈(N,v1+v2)(N, v_1+v_2) 且它可以被分解为 (v1+v2)(S)=v1(S)+v2(S)(v_1+v_2)(S)=v_1(S)+v_2(S) ,则对于任何 v1,v2v_1, v_2 都有

ψi(N,v1+v2)=ψi(N,v1)+ψi(N,v2)\psi _ { i } ( N , v _ { 1 } + v _ { 2 } ) = \psi _ { i } ( N , v _ { 1 } ) + \psi _ { i } ( N , v _ { 2 } )

正式定义

研究已经严格证明,有且仅有一个 ψ\psi 方程同时满足上面三个性质,这就是shapley value。

ψi(N,v)=1N!SN{i} S!(NS1)![v(S{i})v(S)]\psi _ { i } ( N , v ) = \frac { 1 } { | N | ! } \sum _ { S \in N \diagdown \left\{ i \right\} } \ | S | ! \left( | N | - | S | - 1 \right) ! \left[ v \left( S \cup \left\{ i \right\} \right) - v \left( S \right) \right]

最左边的 1N!\frac { 1 } { | N | ! } 是组合的概率,sum中的每一项都有不同的 SS 。这个 SSN{i}N \diagdown \left\{ i \right\}(表示除去该玩家的集合) 的任意子集(可为空集)。例如上面的例子,当 i=1i=1 时, S=,{2},{3},{2,3}S=\empty ,\{2\},\{3\},\{2,3\} 都可能成立。

总的来说,Shapley Value反映了一个agent边际贡献的期望。

解释线性回归模型

下载数据集

数据集地址: https://www.dcc.fc.up.pt/~ltorgo/Regression/cal_housing.html

该数据集由 1990 年加利福尼亚州的 20,640 个房屋街区组成,我们的目标是从 8 个不同的特征预测房价中位数的自然对数:

  1. MedInc - 区块组的收入中位数
  2. HouseAge - 区块组中的中位房屋年龄
  3. AveRooms - 每户平均房间数
  4. AveBedrms - 每户平均卧室数量
  5. Population - 区块组总体
  6. AveOccup - 平均住户人数
  7. Latitude - 区块组纬度
  8. Longitude - 区块组经度
import sklearn
import shap
import pandas as pd

domain_path = r"C:\Users\pdnbplus\Downloads\cal_housing\cal_housing\CaliforniaHousing\cal_housing.domain"
with open(domain_path, 'r') as file:
    domain_content = file.readlines()

# 打印域文件内容
for line in domain_content:
    print(line.strip())

column_names = ['longitude', 'latitude', 'housingMedianAge', 'totalRooms', 'totalBedrooms', 'population', 'households', 'medianIncome', 'medianHouseValue']

data_path = r"C:\Users\pdnbplus\Downloads\cal_housing\cal_housing\CaliforniaHousing\cal_housing.data"

# 使用 pandas 读取数据
data = pd.read_csv(data_path, sep=',', header=None, names=column_names)
# 查看数据前几行
print(data.head())
# 检查缺失值 无缺失值
print(data.isnull().sum())

# 取1000条数据进行分析
data = data.sample(n=1000, random_state=42)
# 将目标变量从数据集中分离出来
y = data.pop('medianHouseValue')

# 现在 X 包含了特征,y 包含了目标变量
X = data

构建线性模型 检查模型系数

理解线性模型的最常见方法是检查为每个特征学习的系数。这些系数告诉我们,当我们更改每个输入特征时,模型输出发生了多大的变化:

# 采样 100 个实例作为背景分布
X100 = shap.utils.sample(X, 100)

# 训练一个简单的线性模型
model = sklearn.linear_model.LinearRegression()
model.fit(X, y)

print("Model coefficients:\n")
for i in range(X.shape[1]):
    print(X.columns[i], "=", model.coef_[i].round(5))
Model coefficients:

longitude = -43300.7596
latitude = -42607.04902
housingMedianAge = 1047.04257
totalRooms = -3.1356
totalBedrooms = 70.34237
population = -61.29787
households = 130.78059
medianIncome = 36736.79312

虽然系数非常适合告诉我们更改输入特征的值时会发生什么,但它们本身并不是衡量特征整体重要性的好方法。这是因为每个系数的值取决于输入要素的比例。例如,如果我们以分钟而不是年为单位来测量房屋的年龄,那么 housingMedianAge 特征的系数将变为 1047 / (365∗24∗60) = 0.002 。显然,自房屋建成以来的年数并不比分钟数更重要,但它的系数值要大得多。这意味着系数的大小不一定是衡量特征在线性模型中重要性的良好指标。

绘制偏依赖图

要理解一个特征在模型中的重要性,我们需要了解两个方面:一是改变该特征如何影响模型的输出,二是该特征值的分布情况。为了在一个线性模型中可视化这一点,我们可以构建一个经典的偏依赖图(Partial Dependence Plot, PDP),并在 x 轴上用直方图展示该特征值的分布。

shap.partial_dependence_plot(
    "medianIncome",
    model.predict,
    X100,
    ice=False,
    model_expected_value=True,
    feature_expected_value=True,
)

在这里插入图片描述

上图中的灰色水平线代表了模型应用于加州住房数据集时的预期值。灰色垂直线代表了中位收入特征的平均值。请注意,蓝色的偏依赖图线(当我们固定中位收入特征为某个特定值时,模型输出的平均值)总是通过两条灰色预期值线的交点。我们可以将这个交点视为相对于数据分布的“中心点”。当我们接下来讨论 Shapley 值时,这种居中的影响将会变得更加清晰。

从偏依赖图中读取 SHAP 值

机器学习模型基于 Shapley 值的解释背后的核心思想是使用合作博弈论的公平分配结果,在其输入特征之间分配模型输出 f(x)f(x) 的信用。为了将博弈论与机器学习模型联系起来,既需要将模型的输入特征与游戏中的玩家相匹配,又需要将模型函数与游戏规则相匹配。由于在博弈论中,玩家可以加入或不加入游戏,因此我们需要一种方法来让功能“加入”或“不加入”模型。定义特征“加入”模型的含义的最常见方法是,当我们知道该特征的值时,该特征已经“加入模型”,而当我们不知道该特征的值时,它没有加入模型。为了在只有一部分 SS 特征是模型时评估现有模型 ff ,我们使用条件期望值公式将其他特征进行积分。这种表述可以有两种形式:

E[f(X)XS=xS]orE[f(X)do(XS=xS)]E [ f ( X ) \mid X _ { S } = x _ { S } ] \\ \text{or} \\ E [ f ( X ) \mid d o ( X _ { S } = x _ { S } ) ]

在第一种形式中,我们知道 S 中特征的值,因为我们观察它们。在第二种形式中,我们知道 S 中特征的值,因为我们设置了它们。一般来说,第二种形式通常是更可取的,因为它告诉我们如果我们干预并更改其输入,模型将如何表现,还因为它更容易计算。在本教程中,我们将完全关注第二个公式。我们还将使用更具体的术语“SHAP 值”来指代应用于机器学习模型的条件期望函数的 Shapley 值。

SHAP 值的计算可能非常复杂(它们通常是 NP 困难的),但线性模型非常简单,我们可以直接从部分依赖图中读取 SHAP 值。当我们解释预测f(x)f(x)时 ,特定特征 i 的 SHAP 值只是预期模型输出与特征值 xix_i 处的部分依赖图之间的差值:

# compute the SHAP values for the linear model
explainer = shap.Explainer(model.predict, X100)
shap_values = explainer(X)

# make a standard partial dependence plot
sample_ind = 55
shap.partial_dependence_plot(
    "medianIncome",
    model.predict,
    X100,
    model_expected_value=True,
    feature_expected_value=True,
    ice=False,
    shap_values=shap_values[sample_ind : sample_ind + 1, :],
)

在这里插入图片描述

经典的部分依赖图和 SHAP 值之间的紧密对应意味着,如果我们在整个数据集中绘制特定特征的 SHAP 值,我们将精确地绘制出该特征的部分依赖图的均值中心版本:

shap.plots.scatter(shap_values[:, "MedInc"])

在这里插入图片描述

Shapley 值的加法性质

Shapley值的一个基本属性是,当所有玩家都在场时,它们总是总结出游戏结果和没有玩家在场时的游戏结果之间的差异。对于机器学习模型,这意味着所有输入特征的SHAP值总是总结出基线(预期)模型输出和当前模型输出之间的差异,用于解释预测。最简单的方法是通过瀑布图来理解这一点,该图从我们对房价E[f(x)]E[f(x)]的背景先验预期开始,然后一次添加一个特征,直到我们达到当前模型输出f(x)f(x)

# the waterfall_plot shows how we get from shap_values.base_values to model.predict(X)[sample_ind]
shap.plots.waterfall(shap_values[sample_ind], max_display=14)

在这里插入图片描述

该图的看法是从下往上看,看通过加入哪些特征将E[f(x)]E[f(x)]变为f(x)f(x)

解释加性回归模型

线性模型的部分依赖图与 SHAP 值如此紧密的联系的原因是,模型中的每个特征都是独立于所有其他特征处理的(效应只是相加在一起)。我们可以保持这种加法性质,同时放宽直线的线性要求。这导致了众所周知的广义加法模型 (GAM) 类。虽然有很多方法可以训练这些类型的模型(例如将 XGBoost 模型设置为 depth-1),但我们将使用专为此设计的 InterpretML 可解释的提升机。

# fit a GAM model to the data
import interpret.glassbox

model_ebm = interpret.glassbox.ExplainableBoostingRegressor(interactions=0)
model_ebm.fit(X, y)

# explain the GAM model with SHAP
explainer_ebm = shap.Explainer(model_ebm.predict, X100)
shap_values_ebm = explainer_ebm(X)

# make a standard partial dependence plot with a single SHAP value overlaid
fig, ax = shap.partial_dependence_plot(
    "medianIncome",
    model_ebm.predict,
    X100,
    model_expected_value=True,
    feature_expected_value=True,
    show=False,
    ice=False,
    shap_values=shap_values_ebm[sample_ind : sample_ind + 1, :],
)

在这里插入图片描述

shap.plots.scatter(shap_values_ebm[:, "medianIncome"])

在这里插入图片描述

# the waterfall_plot shows how we get from explainer.expected_value to model.predict(X)[sample_ind]
shap.plots.waterfall(shap_values_ebm[sample_ind])

在这里插入图片描述

# the waterfall_plot shows how we get from explainer.expected_value to model.predict(X)[sample_ind]
shap.plots.beeswarm(shap_values_ebm)

在这里插入图片描述

解释非加性提升树模型

# train XGBoost model
import xgboost

model_xgb = xgboost.XGBRegressor(n_estimators=100, max_depth=2).fit(X, y)

# explain the GAM model with SHAP
explainer_xgb = shap.Explainer(model_xgb, X100)
shap_values_xgb = explainer_xgb(X)

# make a standard partial dependence plot with a single SHAP value overlaid
fig, ax = shap.partial_dependence_plot(
    "medianIncome",
    model_xgb.predict,
    X100,
    model_expected_value=True,
    feature_expected_value=True,
    show=False,
    ice=False,
    shap_values=shap_values_xgb[sample_ind : sample_ind + 1, :],
)

在这里插入图片描述

shap.plots.scatter(shap_values_xgb[:, "medianIncome"])

在这里插入图片描述

shap.plots.scatter(shap_values_xgb[:, "medianIncome"], color=shap_values_xgb[:, 4])

在这里插入图片描述

解释线性 Logistic 回归模型

# 使用adult census income数据集
X_adult, y_adult = shap.datasets.adult()

# a simple linear logistic model
model_adult = sklearn.linear_model.LogisticRegression(max_iter=10000)
model_adult.fit(X_adult, y_adult)


def model_adult_proba(x):
    return model_adult.predict_proba(x)[:, 1]


def model_adult_log_odds(x):
    p = model_adult.predict_log_proba(x)
    return p[:, 1] - p[:, 0]

# make a standard partial dependence plot
sample_ind = 18
fig, ax = shap.partial_dependence_plot(
    "Capital Gain",
    model_adult_proba,
    X_adult,
    model_expected_value=True,
    feature_expected_value=True,
    show=False,
    ice=False,
)

在这里插入图片描述

注意:在解释线性逻辑回归模型的概率输出时,这个概率并不是输入特征的线性函数。虽然逻辑回归模型的形式看起来像是一个线性组合,但最终的概率输出是通过一个非线性的 sigmoid 函数(也称为 logistic 函数)转换得到的。

如果我们使用 SHAP 来解释线性 logistic 回归模型的概率,我们会看到很强的交互作用。这是因为线性 logistic 回归模型在概率空间中不是加法的。

# compute the SHAP values for the linear model
background_adult = shap.maskers.Independent(X_adult, max_samples=100)
explainer = shap.Explainer(model_adult_proba, background_adult)
shap_values_adult = explainer(X_adult[:1000])
shap.plots.scatter(shap_values_adult[:, "Age"])

在这里插入图片描述

如果我们解释模型使用的是对数几率(log-odds)输出,我们会看到模型的输入和输出之间存在完美的线性关系。重要的是要记住你所解释的模型的单位是什么,以及解释不同的模型输出可以导致对模型行为的非常不同的理解。

# compute the SHAP values for the linear model
explainer_log_odds = shap.Explainer(model_adult_log_odds, background_adult)
shap_values_adult_log_odds = explainer_log_odds(X_adult[:1000])
shap.plots.scatter(shap_values_adult_log_odds[:, "Age"])

在这里插入图片描述

sample_ind = 18
fig, ax = shap.partial_dependence_plot(
    "Age",
    model_adult_log_odds,
    X_adult,
    model_expected_value=True,
    feature_expected_value=True,
    show=False,
    ice=False,
)

在这里插入图片描述

解释非加性提升树 Logistic 回归模型

import xgboost
# train XGBoost model
model = xgboost.XGBClassifier(n_estimators=100, max_depth=2).fit(X_adult, y_adult * 1)

# compute SHAP values
explainer = shap.Explainer(model, background_adult)
shap_values = explainer(X_adult)

# set a display version of the data to use for plotting (has string values)
shap_values.display_data = shap.datasets.adult(display=True)[0].values

shap.plots.bar(shap_values)

在这里插入图片描述

默认情况下,SHAP 条形图将采用数据集所有实例(行)上每个特征的平均绝对值。

但是平均绝对值并不是创建特征重要性全局度量的唯一方法,我们可以使用任意数量的转换。在这里,我们展示了如何使用最大绝对值来强调 Capital Gain 和 Capital Loss 特征,因为它们具有不常见但幅度较大的影响。

shap.plots.bar(shap_values.abs.max(0))

在这里插入图片描述

如果我们愿意处理更多的复杂性,我们可以使用蜂群图来总结每个特征的 SHAP 值的整个分布。

shap.plots.beeswarm(shap_values)

在这里插入图片描述

shap.plots.heatmap(shap_values[:1000])
shap.plots.scatter(shap_values[:, "Age"], color=shap_values)
shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:, "Capital Gain"])
shap.plots.scatter(shap_values[:, "Relationship"], color=shap_values)

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

处理关联特征

clustering = shap.utils.hclust(X_adult, y_adult)
shap.plots.bar(shap_values, clustering=clustering)
shap.plots.bar(shap_values, clustering=clustering, clustering_cutoff=0.8)
shap.plots.bar(shap_values, clustering=clustering, clustering_cutoff=1.8)

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

解释 transformers NLP 模型

import datasets
import numpy as np
import scipy as sp
import torch
import transformers

# load a BERT sentiment analysis model
tokenizer = transformers.DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
model = transformers.DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english"
).cuda()


# define a prediction function
def f(x):
    tv = torch.tensor([tokenizer.encode(v, padding="max_length", max_length=500, truncation=True) for v in x]).cuda()
    outputs = model(tv)[0].detach().cpu().numpy()
    scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T
    val = sp.special.logit(scores[:, 1])  # use one vs rest logit units
    return val


# build an explainer using a token masker
explainer = shap.Explainer(f, tokenizer)

# explain the model's predictions on IMDB reviews
imdb_train = datasets.load_dataset("imdb")["train"]
shap_values = explainer(imdb_train[:10], fixed_context=1, batch_size=2)

这个生成的图就是可交互的形式了

在这里插入图片描述