Reference code
https://github.com/laiguokun/LSTNet
Initialization
import math
import torch.optim as optim

class Optim(object):

    def _makeOptimizer(self):
        if self.method == 'sgd':
            self.optimizer = optim.SGD(self.params, lr=self.lr)
        elif self.method == 'adagrad':
            self.optimizer = optim.Adagrad(self.params, lr=self.lr)
        elif self.method == 'adadelta':
            self.optimizer = optim.Adadelta(self.params, lr=self.lr)
        elif self.method == 'adam':
            self.optimizer = optim.Adam(self.params, lr=self.lr)
        else:
            raise RuntimeError("Invalid optim method: " + self.method)

    def __init__(self, params, method, lr, max_grad_norm, lr_decay=1, start_decay_at=None):
        self.params = list(params)  # careful: params may be a generator
        self.last_ppl = None
        self.lr = lr
        self.max_grad_norm = max_grad_norm
        self.method = method
        self.lr_decay = lr_decay
        self.start_decay_at = start_decay_at
        self.start_decay = False

        self._makeOptimizer()
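For context, a minimal sketch of constructing this wrapper (the hyperparameter values below are illustrative only; the repo presumably takes them from command-line arguments, and model stands for any nn.Module):

# Illustrative values, not the repo's defaults.
optimizer_wrapper = Optim(
    model.parameters(), 'adam', lr=0.001,
    max_grad_norm=10., lr_decay=0.9, start_decay_at=50
)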
Gradient clipping
This keeps the L2 norm of the gradient vector from exceeding self.max_grad_norm.
    def step(self):
        # Compute the total L2 norm of all parameter gradients.
        grad_norm = 0
        for param in self.params:
            grad_norm += math.pow(param.grad.data.norm(), 2)
        grad_norm = math.sqrt(grad_norm)

        if grad_norm > 0:
            shrinkage = self.max_grad_norm / grad_norm
        else:
            shrinkage = 1.

        # Rescale every gradient by the same factor when the norm exceeds the limit.
        for param in self.params:
            if shrinkage < 1:
                param.grad.data.mul_(shrinkage)

        self.optimizer.step()
        return grad_norm
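The manual norm computation above predates the now-standard PyTorch helper; a roughly equivalent drop-in sketch (not what the repo uses) based on torch.nn.utils.clip_grad_norm_ would be:

from torch.nn.utils import clip_grad_norm_

    def step(self):
        # clip_grad_norm_ rescales the gradients in place so their combined L2 norm
        # does not exceed max_grad_norm, and returns the pre-clipping norm.
        grad_norm = clip_grad_norm_(self.params, self.max_grad_norm)
        self.optimizer.step()
        return grad_norm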
Learning rate decay
The learning rate is lowered only when the decay start point self.start_decay_at has been reached, or when performance on the validation set has not improved (ppl > self.last_ppl).
    # decay learning rate if validation performance does not improve or we hit the start_decay_at limit
    def updateLearningRate(self, ppl, epoch):
        if self.start_decay_at is not None and epoch >= self.start_decay_at:
            self.start_decay = True
        if self.last_ppl is not None and ppl > self.last_ppl:
            self.start_decay = True

        if self.start_decay:
            self.lr = self.lr * self.lr_decay
            print("Decaying learning rate to %g" % self.lr)
            # only decay for one epoch
            self.start_decay = False

        self.last_ppl = ppl

        # Rebuild the optimizer so the new learning rate takes effect.
        self._makeOptimizer()
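Conceptually this resembles PyTorch's built-in ReduceLROnPlateau scheduler; a minimal sketch of that alternative (not used by LSTNet, with a placeholder model) would look like:

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(10, 1)   # placeholder model, for illustration only
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Multiply lr by `factor` whenever the monitored metric stops improving.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=0)

val_loss = 1.0             # would be this epoch's validation metric
scheduler.step(val_loss)   # analogue of optim.updateLearningRate(val_loss, epoch)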
Usage
for epoch in tqdm(range(1, args.epochs + 1)):
    epoch_start_time = time.time()
    train_loss = train(Data, Data.train[0], Data.train[1], model, criterion, optim, args.batch_size)
    val_loss, val_rae, val_corr = evaluate(Data, Data.valid[0], Data.valid[1], model, evaluateL2, evaluateL1, args.batch_size)
    print('| end of epoch {:3d} | time: {:5.2f}s | train_loss {:5.4f} | valid rse {:5.4f} | valid rae {:5.4f} | valid corr {:5.4f}'.format(
        epoch, (time.time() - epoch_start_time), train_loss, val_loss, val_rae, val_corr))

    # Save the model if the validation loss is the best we've seen so far.
    if val_loss < best_val:
        with open(args.save, 'wb') as f:
            torch.save(model, f)
        best_val = val_loss

    if epoch % 5 == 0:
        test_acc, test_rae, test_corr = evaluate(Data, Data.test[0], Data.test[1], model, evaluateL2, evaluateL1, args.batch_size)
        print("test rse {:5.4f} | test rae {:5.4f} | test corr {:5.4f}".format(test_acc, test_rae, test_corr))

    optim.updateLearningRate(val_loss, epoch)  # <<< here!
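train() itself is not shown above; as a hedged sketch, one batch update inside it presumably follows the usual pattern, with the wrapper's step() taking the place of a bare optimizer.step():

# Hypothetical inner loop of train(); X, Y stand for one mini-batch.
model.zero_grad()
output = model(X)
loss = criterion(output, Y)
loss.backward()
grad_norm = optim.step()  # clips the gradients, then applies the optimizer update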