Pytorch Ignite 使用方法
下载 pip install ignite
官方网址:https://pytorch.org/ignite/concepts.html
Engine
该框架的本质是class Engine
,它是一种抽象形式,它在提供的数据上循环给定的次数,执行处理函数并返回结果:
while epoch < max_epochs:
# run an epoch on data
data_iter = iter(data)
while True:
try:
batch = next(data_iter)
output = process_function(batch)
iter_counter += 1
except StopIteration:
data_iter = iter(data)
if iter_counter == epoch_length:
break
因此,模型训练器只是一个引擎,它在训练数据集上循环多次并更新模型参数。同样,可以使用在验证数据集上运行一次并计算指标的引擎来完成模型评估。
例如,用于监督任务的模型训练器:
def train_step(trainer, batch):
model.train()
optimizer.zero_grad()
x, y = prepare_batch(batch)
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
return loss.item()
trainer = Engine(train_step)
trainer.run(data, max_epochs=100)
训练步骤的输出类型(即在上面的示例中loss.item()
)不受限制。训练步骤功能可以返回用户想要的一切。输出设置为,trainer.state.output
并且可以进一步用于任何类型的处理。
默认情况下,epoch_length 长度由
len(data)
定义。但是,用户也可以手动将epoch_length 长度定义为要循环的多次迭代。这样,输入数据可以是迭代器。trainer.run(data, max_epochs=100, epoch_length=200)
如果data是长度未知的有限数据迭代器(对于用户),
epoch_length
则可以省略参数,并且在耗尽数据迭代器时将自动确定参数。
任何复杂度的训练逻辑都可以使用train_step
方法进行编码,并且可以使用此方法来构造训练器。
函数batch
中的train_step
参数是用户定义的,可以包含单个迭代所需的任何数据。
# 定义模型参数
model_1 = ...
model_2 = ...
# 定义优化器
optimizer_1 = ...
optimizer_2 = ...
#
criterion_1 = ...
criterion_2 = ...
# ...
def train_step(trainer, batch):
data_1 = batch["data_1"]
data_2 = batch["data_2"]
# ...
model_1.train()
optimizer_1.zero_grad()
loss_1 = forward_pass(data_1, model_1, criterion_1)
loss_1