
Commit ad73ea2

fehiepsi authored and facebook-github-bot committed
Add strong Wolfe line search for lbfgs (#8824)
Summary: This pull request adds a line search for lbfgs. "strong Wolfe" is the default line search method in [minFunc](https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html) and it is also the method recommended in the [Numerical Optimization](https://www.springer.com/gp/book/9780387303031) book. The implementation is based on four sources:

+ https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html
+ https://www.springer.com/gp/book/9780387303031 (Algorithms 3.5 and 3.6, formula 3.59)
+ https://github.com/torch/optim/blob/master/lswolfe.lua
+ https://github.com/torch/optim/blob/master/polyinterp.lua

The Lua version is based on an old version of `minFunc`, which was updated in 2012, so I made a couple of small changes based on the updated version. Because of that, the test comparing against the `.lua` version is not consistent (which is why I changed a learning rate in the test).

Pull Request resolved: #8824
Differential Revision: D15783067
Pulled By: vincentqb
fbshipit-source-id: 5316d9088233981120376d79c7869d5f97e51b69
1 parent 2c91ba3 commit ad73ea2
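A minimal usage sketch (not part of this commit): the new behaviour is opt-in via the `line_search_fn="strong_Wolfe"` constructor argument added below, and `step()` still takes a closure that re-evaluates the objective and returns the loss. The toy least-squares problem and all variable names here are purely illustrative.

    import torch
    from torch import optim

    # toy least-squares problem: minimize ||A x - b||^2
    A = torch.randn(10, 3)
    b = torch.randn(10)
    x = torch.zeros(3, requires_grad=True)

    optimizer = optim.LBFGS([x], lr=1, line_search_fn="strong_Wolfe")

    def closure():
        # LBFGS may evaluate the objective several times per step,
        # so the loss must be recomputed inside the closure
        optimizer.zero_grad()
        loss = (A.mv(x) - b).pow(2).sum()
        loss.backward()
        return loss

    for _ in range(5):
        optimizer.step(closure)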

File tree: 2 files changed (+217, −21 lines)


test/test_optim.py

Lines changed: 4 additions & 0 deletions
@@ -447,6 +447,10 @@ def test_lbfgs(self):
             lambda weight, bias: optim.LBFGS([weight, bias]),
             ignore_multidevice=True
         )
+        self._test_basic_cases(
+            lambda weight, bias: optim.LBFGS([weight, bias], line_search_fn="strong_Wolfe"),
+            ignore_multidevice=True
+        )

     @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN")
     def test_lbfgs_return_type(self):

torch/optim/lbfgs.py

Lines changed: 213 additions & 21 deletions
@@ -3,8 +3,171 @@
 from .optimizer import Optimizer


+def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
+    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
+    # Compute bounds of interpolation area
+    if bounds is not None:
+        xmin_bound, xmax_bound = bounds
+    else:
+        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
+
+    # Code for most common case: cubic interpolation of 2 points
+    #   w/ function and derivative values for both
+    # Solution in this case (where x2 is the farthest point):
+    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
+    #   d2 = sqrt(d1^2 - g1*g2);
+    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
+    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
+    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
+    d2_square = d1 ** 2 - g1 * g2
+    if d2_square >= 0:
+        d2 = d2_square.sqrt()
+        if x1 <= x2:
+            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
+        else:
+            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
+        return min(max(min_pos, xmin_bound), xmax_bound)
+    else:
+        return (xmin_bound + xmax_bound) / 2.
+
+
+def _strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9,
+                  max_ls=25):
+    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
+    d_norm = d.abs().max()
+    g = g.clone()
+    # evaluate objective and gradient using initial step
+    f_new, g_new = obj_func(x, t, d)
+    ls_func_evals = 1
+    gtd_new = g_new.dot(d)
+
+    # bracket an interval containing a point satisfying the Wolfe criteria
+    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
+    done = False
+    ls_iter = 0
+    while ls_iter < max_ls:
+        # check conditions
+        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
+            bracket = [t_prev, t]
+            bracket_f = [f_prev, f_new]
+            bracket_g = [g_prev, g_new.clone()]
+            bracket_gtd = [gtd_prev, gtd_new]
+            break
+
+        if abs(gtd_new) <= -c2 * gtd:
+            bracket = [t]
+            bracket_f = [f_new]
+            bracket_g = [g_new]
+            done = True
+            break
+
+        if gtd_new >= 0:
+            bracket = [t_prev, t]
+            bracket_f = [f_prev, f_new]
+            bracket_g = [g_prev, g_new.clone()]
+            bracket_gtd = [gtd_prev, gtd_new]
+            break
+
+        # interpolate
+        min_step = t + 0.01 * (t - t_prev)
+        max_step = t * 10
+        tmp = t
+        t = _cubic_interpolate(t_prev, f_prev, gtd_prev, t, f_new, gtd_new,
+                               bounds=(min_step, max_step))
+
+        # next step
+        t_prev = tmp
+        f_prev = f_new
+        g_prev = g_new.clone()
+        gtd_prev = gtd_new
+        f_new, g_new = obj_func(x, t, d)
+        ls_func_evals += 1
+        gtd_new = g_new.dot(d)
+        ls_iter += 1
+
+    # reached max number of iterations?
+    if ls_iter == max_ls:
+        bracket = [0, t]
+        bracket_f = [f, f_new]
+        bracket_g = [g, g_new]
+
+    # zoom phase: we now have a point satisfying the criteria, or
+    # a bracket around it. We refine the bracket until we find the
+    # exact point satisfying the criteria
+    insuf_progress = False
+    # find high and low points in bracket
+    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
+    while not done and ls_iter < max_ls:
+        # compute new trial value
+        t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
+                               bracket[1], bracket_f[1], bracket_gtd[1])
+
+        # test that we are making sufficient progress:
+        # in case `t` is so close to boundary, we mark that we are making
+        # insufficient progress, and if
+        #   + we have made insufficient progress in the last step, or
+        #   + `t` is at one of the boundary,
+        # we will move `t` to a position which is `0.1 * len(bracket)`
+        # away from the nearest boundary point.
+        eps = 0.1 * (max(bracket) - min(bracket))
+        if min(max(bracket) - t, t - min(bracket)) < eps:
+            # interpolation close to boundary
+            if insuf_progress or t >= max(bracket) or t <= min(bracket):
+                # evaluate at 0.1 away from boundary
+                if abs(t - max(bracket)) < abs(t - min(bracket)):
+                    t = max(bracket) - eps
+                else:
+                    t = min(bracket) + eps
+                insuf_progress = False
+            else:
+                insuf_progress = True
+        else:
+            insuf_progress = False
+
+        # Evaluate new point
+        f_new, g_new = obj_func(x, t, d)
+        ls_func_evals += 1
+        gtd_new = g_new.dot(d)
+        ls_iter += 1
+
+        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
+            # Armijo condition not satisfied or not lower than lowest point
+            bracket[high_pos] = t
+            bracket_f[high_pos] = f_new
+            bracket_g[high_pos] = g_new.clone()
+            bracket_gtd[high_pos] = gtd_new
+            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
+        else:
+            if abs(gtd_new) <= -c2 * gtd:
+                # Wolfe conditions satisfied
+                done = True
+            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
+                # old high becomes new low
+                bracket[high_pos] = bracket[low_pos]
+                bracket_f[high_pos] = bracket_f[low_pos]
+                bracket_g[high_pos] = bracket_g[low_pos]
+                bracket_gtd[high_pos] = bracket_gtd[low_pos]
+
+            # new point becomes new low
+            bracket[low_pos] = t
+            bracket_f[low_pos] = f_new
+            bracket_g[low_pos] = g_new.clone()
+            bracket_gtd[low_pos] = gtd_new
+
+        # line-search bracket is so small
+        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
+            break
+
+    # return stuff
+    t = bracket[low_pos]
+    f_new = bracket_f[low_pos]
+    g_new = bracket_g[low_pos]
+    return f_new, g_new, t, ls_func_evals
+
+
 class LBFGS(Optimizer):
-    """Implements L-BFGS algorithm.
+    """Implements L-BFGS algorithm, heavily inspired by `minFunc
+    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`.

     .. warning::
         This optimizer doesn't support per-parameter options and parameter
@@ -30,6 +193,7 @@ class LBFGS(Optimizer):
         tolerance_change (float): termination tolerance on function
             value/parameter changes (default: 1e-9).
         history_size (int): update history size (default: 100).
+        line_search_fn (str): either 'strong_Wolfe' or None (default: None).
     """

     def __init__(self, params, lr=1, max_iter=20, max_eval=None,
@@ -58,11 +222,11 @@ def _gather_flat_grad(self):
         views = []
         for p in self._params:
             if p.grad is None:
-                view = p.data.new(p.data.numel()).zero_()
-            elif p.grad.data.is_sparse:
-                view = p.grad.data.to_dense().view(-1)
+                view = p.new(p.numel()).zero_()
+            elif p.grad.is_sparse:
+                view = p.grad.to_dense().view(-1)
             else:
-                view = p.grad.data.view(-1)
+                view = p.grad.view(-1)
             views.append(view)
         return torch.cat(views, 0)

@@ -75,6 +239,20 @@ def _add_grad(self, step_size, update):
             offset += numel
         assert offset == self._numel()

+    def _clone_param(self):
+        return [p.clone() for p in self._params]
+
+    def _set_param(self, params_data):
+        for p, pdata in zip(self._params, params_data):
+            p.data.copy_(pdata)
+
+    def _directional_evaluate(self, closure, x, t, d):
+        self._add_grad(t, d)
+        loss = float(closure())
+        flat_grad = self._gather_flat_grad()
+        self._set_param(x)
+        return loss, flat_grad
+
     def step(self, closure):
         """Performs a single optimization step.

@@ -106,16 +284,18 @@ def step(self, closure):
         state['func_evals'] += 1

         flat_grad = self._gather_flat_grad()
-        abs_grad_sum = flat_grad.abs().sum()
+        opt_cond = flat_grad.abs().max() <= tolerance_grad

-        if abs_grad_sum <= tolerance_grad:
+        # optimal condition
+        if opt_cond:
             return orig_loss

         # tensors cached in state (for tracing)
         d = state.get('d')
         t = state.get('t')
         old_dirs = state.get('old_dirs')
         old_stps = state.get('old_stps')
+        ro = state.get('ro')
         H_diag = state.get('H_diag')
         prev_flat_grad = state.get('prev_flat_grad')
         prev_loss = state.get('prev_loss')
@@ -134,6 +314,7 @@ def step(self, closure):
                 d = flat_grad.neg()
                 old_dirs = []
                 old_stps = []
+                ro = []
                 H_diag = 1
             else:
                 # do lbfgs update (update memory)
@@ -146,10 +327,12 @@ def step(self, closure):
                         # shift history by one (limited-memory)
                         old_dirs.pop(0)
                         old_stps.pop(0)
+                        ro.pop(0)

                     # store new direction/step
                     old_dirs.append(y)
                     old_stps.append(s)
+                    ro.append(1. / ys)

                     # update scale of initial Hessian approximation
                     H_diag = ys / y.dot(y)  # (y*y)
@@ -158,15 +341,10 @@ def step(self, closure):
             # multiplied by the gradient
             num_old = len(old_dirs)

-            if 'ro' not in state:
-                state['ro'] = [None] * history_size
-                state['al'] = [None] * history_size
-            ro = state['ro']
+            if 'al' not in state:
+                state['al'] = [None] * history_size
             al = state['al']

-            for i in range(num_old):
-                ro[i] = 1. / old_dirs[i].dot(old_stps[i])
-
             # iteration in L-BFGS loop collapsed to use just one buffer
             q = flat_grad.neg()
             for i in range(num_old - 1, -1, -1):
@@ -191,18 +369,32 @@ def step(self, closure):
             ############################################################
             # reset initial guess for step size
             if state['n_iter'] == 1:
-                t = min(1., 1. / abs_grad_sum) * lr
+                t = min(1., 1. / flat_grad.abs().sum()) * lr
             else:
                 t = lr

             # directional derivative
             gtd = flat_grad.dot(d)  # g * d

+            # directional derivative is below tolerance
+            if gtd > -tolerance_change:
+                break
+
             # optional line search: user function
             ls_func_evals = 0
             if line_search_fn is not None:
                 # perform line search, using user function
-                raise RuntimeError("line search function is not supported yet")
+                if line_search_fn != "strong_Wolfe":
+                    raise RuntimeError("only 'strong_Wolfe' is supported")
+                else:
+                    x_init = self._clone_param()
+
+                    def obj_func(x, t, d):
+                        return self._directional_evaluate(closure, x, t, d)
+                    loss, flat_grad, t, ls_func_evals = _strong_Wolfe(obj_func, x_init, t, d,
+                                                                      loss, flat_grad, gtd)
+                    self._add_grad(t, d)
+                    opt_cond = flat_grad.abs().max() <= tolerance_grad
             else:
                 # no line search, simply move with fixed-step
                 self._add_grad(t, d)
@@ -212,7 +404,7 @@ def step(self, closure):
                     # no use to re-evaluate that function here
                     loss = float(closure())
                     flat_grad = self._gather_flat_grad()
-                    abs_grad_sum = flat_grad.abs().sum()
+                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                     ls_func_evals = 1

             # update func eval
@@ -228,13 +420,12 @@ def step(self, closure):
             if current_evals >= max_eval:
                 break

-            if abs_grad_sum <= tolerance_grad:
-                break
-
-            if gtd > -tolerance_change:
+            # optimal condition
+            if opt_cond:
                 break

-            if d.mul(t).abs_().sum() <= tolerance_change:
+            # lack of progress
+            if d.mul(t).abs().max() <= tolerance_change:
                 break

             if abs(loss - prev_loss) < tolerance_change:
@@ -244,6 +435,7 @@ def step(self, closure):
         state['t'] = t
         state['old_dirs'] = old_dirs
         state['old_stps'] = old_stps
+        state['ro'] = ro
         state['H_diag'] = H_diag
         state['prev_flat_grad'] = prev_flat_grad
         state['prev_loss'] = prev_loss
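As a sanity check on the interpolation formula used by `_cubic_interpolate` above (not part of the commit): for a quadratic objective, the cubic fit should recover the exact minimizer. The standalone helper below is a plain-float re-implementation of the same formula; the name `cubic_interpolate` and the sample points are purely illustrative, and the in-tree version additionally clamps the result to bounds and operates on tensors (hence `d2_square.sqrt()`).

    import math

    def cubic_interpolate(x1, f1, g1, x2, f2, g2):
        # same closed-form minimizer as _cubic_interpolate above,
        # without the bounds handling and the d2_square < 0 fallback
        d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
        d2 = math.sqrt(d1 ** 2 - g1 * g2)
        if x1 <= x2:
            return x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        return x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))

    # f(x) = x**2 sampled at (x, f, g) = (1, 1, 2) and (3, 9, 6);
    # the interpolant's minimizer is the true minimizer x = 0
    print(cubic_interpolate(1., 1., 2., 3., 9., 6.))  # 0.0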
