R语言求函数的偏导数

  在Machine Learning的课上,老师讲到用gradient decent的方法解logistic regression中cost function的最小值。这当中就要涉及到对cost function的求偏导数。其实在R当中可以很方便的做到这一点。

  R语言中可以使用D()来求一元函数的导数,用deriv()来求多元函数的偏导数。这两个function都在package:stats中,会在R启动时默认加载。

  下面以一个多元函数作为demo,希望通过gradient decent的方法求最小值:

$$E(u,v) = e^u + e^{2v} + e^{uv} + u^2 - 2uv + 2v^2 - 3u - 2v$$

1
2
3
4
5
6
7
8
# 创建一个expression对象
Euv = expression(exp(u) + exp(2 * v) + exp(u * v) + u^2 - 2 * u * v + 2 * v^2 -
3 * u - 2 * v)
# Euv 在u以及v方向上的偏导数,设定参数func=T,这样d_Euv是一个function,
# 可以通过d_Euv(u,v)去求值,否则,d_Euv是一个expression,要使用eval(d_Euv)去求值。
# 详细参考 ?deriv
d_Euv = deriv(Euv, c("u", "v"), func = T)
d_Euv
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
## function (u, v)
## {
## .expr1 <- exp(u)
## .expr2 <- 2 * v
## .expr3 <- exp(.expr2)
## .expr6 <- exp(u * v)
## .expr10 <- 2 * u
## .value <- .expr1 + .expr3 + .expr6 + u^2 - .expr10 * v +
## 2 * v^2 - 3 * u - .expr2
## .grad <- array(0, c(length(.value), 2L), list(NULL, c("u",
## "v")))
## .grad[, "u"] <- .expr1 + .expr6 * v + .expr10 - .expr2 -
## 3
## .grad[, "v"] <- .expr3 * 2 + .expr6 * u - .expr10 + 2 * .expr2 -
## 2
## attr(.value, "gradient") <- .grad
## .value
## }

  注意看d_Euv这个function里面的内容,d_Euv返回的.value是原函数Euv在(u,v)下的值,而不是偏导数。该function的末尾把偏导数作为attribute附加到这个value上。因此我们可以通过下面的方法把这个偏导数再提取出来。

1
2
3
derivatives = function(u, v) attributes(d_Euv(u, v))$gradient
# 求(0,0)点的偏导数
derivatives(0, 0)
1
2
## u v
## [1,] -2 0

  利用梯度下降法(gradient decent)解Euv最小值。下面只做示范用,只迭代5次,因此得到的值不一定是真正的最小值。利用fixed learning rate $\eta = 0.01$从$(u_0,v_0) = (0,0)$开始,对(u,v)进行更新:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
uv.init = as.matrix(c(u = 0, v = 0)) #把(u,v)的初始值设定为(0,0)
eta = 0.01
t = 5 # 迭代次数
# gradient decent
uv.t = uv.init
for (i in 1:t) {
# derivatives得到的是一个横向的矩阵,先转置成竖向的向量
grad = t(derivatives(uv.t[1], uv.t[2]))
uv.t = uv.t - eta * grad
}
# 经过5次梯度下降后的 (u,v)
print(uv.t)
1
2
3
## [,1]
## u 0.094140
## v 0.001789

  函数Euv在(u,v)上的值可以利用以下两种方式去解

1
2
# 1. 前面说到,d_Euv(u,v)返回值其实是函数Euv的值
print(d_Euv(uv.t[1], uv.t[2]))
1
2
3
4
## [1] 2.825
## attr(,"gradient")
## u v
## [1,] -1.715 -0.0798
1
2
3
4
# 2. 使用eval()
u = uv.t[1]
v = uv.t[2]
print(eval(Euv))
1
## [1] 2.825
喜欢就分享一下吧