r - How to use treeshap based on mlr3 framework? - Stack Overflow

I am trying to use the treeshap package to compute SHAP values, but I get an error when I apply it to a model trained with the mlr3 framework. I am not sure how to call treeshap correctly on mlr3 models.

Could you please provide guidance or an example of how to integrate Treeshap with mlr3 models?

library(shapviz)
library(kernelshap)
library(mlr3)
library(mlr3verse)
library(mlr3learners)

library(treeshap)
library(xgboost)
# fifa20 ships with treeshap; drop the 'work_rate' column
data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
target <- fifa20$target

data$target <- target
tsk <- mlr3::TaskRegr$new(id = "dd", backend = data, target = "target")
as.data.table(lrn())  # list the available learners
learner <- lrn("regr.xgboost")
mm <- resample(
  task = tsk,
  learner = learner,
  resampling = rsmp("cv", folds = 3),
  store_backends = TRUE,
  store_models = TRUE
)
# This call fails: the mlr3 learner itself is passed to unify()
unified <- unify(mm$learners[[1]], data)

Error:

unify.default(mm$learners[[1]], data): Provided model is not of type supported by treeshap.

The model object is

> mm$learners[[1]]
<LearnerRegrXgboost:regr.xgboost>: Extreme Gradient Boosting
* Model: xgb.Booster
* Parameters: nrounds=1000, nthread=1, verbose=0
* Validate: NULL
* Packages: mlr3, mlr3learners, xgboost
* Predict Types:  [response]
* Feature Types: logical, integer, numeric
* Properties: hotstart_forward, importance, internal_tuning, missings,
  validation, weights 

asked Feb 8 at 3:17 by gen linlin; edited 2 days ago by James Z

2 Answers

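The following works. For simplicity it skips parameter tuning and any validation strategy. It passes the fitted booster (learner$model) to shapviz, which computes the SHAP values with xgboost's built-in TreeSHAP rather than with the treeshap package. Note that I am using a logarithmic response.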

library(mlr3verse)
library(shapviz)

data(fifa20, package = "treeshap")

xvars <- colnames(fifa20$data)

# Numeric matrix: log-transformed response plus all features
train <- cbind(log_value = log(fifa20$target + 1e5), data.matrix(fifa20$data))

tsk <- TaskRegr$new(id = "dd", backend = data.frame(train), target = "log_value")
learner <- lrn("regr.xgboost", nrounds = 100)
learner$train(tsk)

# SHAP analysis on 1000 randomly sampled rows; learner$model is the fitted xgb.Booster
X_explain <- train[sample(nrow(train), 1000), learner$model$feature_names]
shap_values <- shapviz(learner$model, X_explain)
sv_importance(shap_values)                # bar plot of mean absolute SHAP values
sv_importance(shap_values, kind = "bee")  # beeswarm summary plot
top_features <- names(head(sv_importance(shap_values, kind = "no")))  # six most important features
sv_dependence(shap_values, top_features)  # one dependence plot per top feature

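I got it! Just replace mm$learners[[1]] with mm$learners[[1]]$model. unify() has no method for the mlr3 learner object; the fitted xgb.Booster that it wraps is stored in the learner's $model field.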

unified <- unify(mm$learners[[1]]$model, data)
treeshap1 <- treeshap(unified,  data[700:800, ], verbose = 0)
treeshap1$shaps[1:3, 1:6]
plot_contribution(treeshap1, obs = 1, min_max = c(0, 16000000))
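
The same fix can be applied to every fold of the resampling, not only the first one. The following is a minimal sketch under stated assumptions, not a definitive implementation: the names X, rr and shap_by_fold, the small nrounds = 20 (chosen only to keep the run short), and the choice to explain the first 100 test rows of each fold are mine, not from the original post. It also assumes versions of mlr3, mlr3learners, xgboost and treeshap on which the snippet above runs, and it passes only the feature columns (without the target) to unify() and treeshap().

```r
library(mlr3)
library(mlr3learners)
library(treeshap)

set.seed(1)

# Feature columns only, plus a training table that also holds the target
features <- fifa20$data[colnames(fifa20$data) != "work_rate"]
train <- cbind(features, target = fifa20$target)

task <- TaskRegr$new(id = "fifa", backend = train, target = "target")
learner <- lrn("regr.xgboost", nrounds = 20)  # assumption: small value for speed

rr <- resample(
  task = task,
  learner = learner,
  resampling = rsmp("cv", folds = 3),
  store_models = TRUE  # required, otherwise $model is not kept
)

# Same column order as the one the learners were trained on
X <- features[, task$feature_names]

shap_by_fold <- lapply(seq_along(rr$learners), function(i) {
  booster <- rr$learners[[i]]$model       # fitted xgb.Booster, not the mlr3 learner
  unified <- unify(booster, X)
  test_rows <- head(rr$resampling$test_set(i), 100)  # held-out rows of fold i
  treeshap(unified, X[test_rows, ], verbose = 0)
})

# SHAP values for the first three explained rows of fold 1
shap_by_fold[[1]]$shaps[1:3, 1:6]
```

Explaining each fold's held-out rows keeps the SHAP values aligned with data the corresponding model did not see during training; pass other rows of X if that is not what you need.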