简介

欢迎来到 LightGBM 的世界,它是一个高效的梯度提升实现 (Ke et al. 2017)。

本文档将引导您了解其基本用法。它将展示如何基于 `bank` 数据集的一个子集 (Moro, Cortez, and Rita 2014) 构建一个简单的二元分类模型。您将使用“age”和“balance”这两个输入特征来预测客户是否订阅了定期存款。

数据集

数据集如下所示。

data(bank, package = "lightgbm")

bank[1L:5L, c("y", "age", "balance")]
#>         y   age balance
#>    <char> <int>   <int>
#> 1:     no    30    1787
#> 2:     no    33    4789
#> 3:     no    35    1350
#> 4:     no    30    1476
#> 5:     no    59       0

# Distribution of the response
table(bank$y)
#> 
#>   no  yes 
#> 4000  521

模型训练

LightGBM 的 R 包提供了两个用于训练模型的函数

  • lgb.train(): 这是主要的训练逻辑。它提供了全面的灵活性,但需要一个由 lgb.Dataset() 函数创建的 Dataset 对象。
  • lightgbm(): 更简单,但灵活性较低。可以直接传递数据,无需使用 lgb.Dataset()

使用 lightgbm() 函数

第一步,您需要将数据转换为数值类型。之后,就可以使用 lightgbm() 函数拟合模型了。

# Numeric response and feature matrix
y <- as.numeric(bank$y == "yes")
X <- data.matrix(bank[, c("age", "balance")])

# Train
fit <- lightgbm(
  data = X
  , label = y
  , params = list(
    num_leaves = 4L
    , learning_rate = 1.0
    , objective = "binary"
  )
  , nrounds = 10L
  , verbose = -1L
)

# Result
summary(predict(fit, X))
#>    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
#> 0.01192 0.07370 0.09871 0.11593 0.14135 0.65796

看起来成功了!并且预测结果确实是介于 0 到 1 之间的概率值。

使用 lgb.train() 函数

或者,您可以使用更灵活的接口 lgb.train()。在这里,您需要额外一步,通过 LightGBM 的数据 API lgb.Dataset() 来准备 yX。参数作为命名列表传递给 lgb.train()

# Data interface
dtrain <- lgb.Dataset(X, label = y)

# Parameters
params <- list(
  objective = "binary"
  , num_leaves = 4L
  , learning_rate = 1.0
)

# Train
fit <- lgb.train(
  params
  , data = dtrain
  , nrounds = 10L
  , verbose = -1L
)

试试看!如果遇到问题,请访问 LightGBM 的 文档 获取更多详细信息。

参考文献

Ke, Guolin, Qi Meng, Thomas Finley, Taifeng Wang, Wei Chen, Weidong Ma, Qiwei Ye, and Tie-Yan Liu. 2017. “LightGBM: A Highly Efficient Gradient Boosting Decision Tree.” In Advances in Neural Information Processing Systems 30 (NIPS 2017).

Moro, Sérgio, Paulo Cortez, and Paulo Rita. 2014. “A Data-Driven Approach to Predict the Success of Bank Telemarketing.” Decision Support Systems 62: 22–31.