Pytorch Ignite 使用方法

Pytorch Ignite 使用方法

下载 pip install ignite
官方网址:https://pytorch.org/ignite/concepts.html

Engine

该框架的本质是class Engine,它是一种抽象形式,它在提供的数据上循环给定的次数,执行处理函数并返回结果:

  while epoch < max_epochs:
    # run an epoch on data
    data_iter = iter(data)
    while True:
        try:
            batch = next(data_iter)
            output = process_function(batch)
            iter_counter += 1
        except StopIteration:
            data_iter = iter(data)

        if iter_counter == epoch_length:
            break

因此,模型训练器只是一个引擎,它在训练数据集上循环多次并更新模型参数。同样,可以使用在验证数据集上运行一次并计算指标的引擎来完成模型评估。

例如,用于监督任务的模型训练器:

def train_step(trainer, batch):
    model.train()
    optimizer.zero_grad()
    x, y = prepare_batch(batch)
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_step)
trainer.run(data, max_epochs=100)

训练步骤的输出类型(即在上面的示例中loss.item())不受限制。训练步骤功能可以返回用户想要的一切。输出设置为,trainer.state.output并且可以进一步用于任何类型的处理。

默认情况下,epoch_length 长度由len(data)定义。但是,用户也可以手动将epoch_length 长度定义为要循环的多次迭代。这样,输入数据可以是迭代器。

trainer.run(data, max_epochs=100, epoch_length=200)

如果data是长度未知的有限数据迭代器(对于用户),epoch_length则可以省略参数,并且在耗尽数据迭代器时将自动确定参数。

任何复杂度的训练逻辑都可以使用train_step方法进行编码,并且可以使用此方法来构造训练器。

函数batch中的train_step参数是用户定义的,可以包含单个迭代所需的任何数据。

# 定义模型参数
model_1 = ...
model_2 = ...
# 定义优化器
optimizer_1 = ...
optimizer_2 = ...
# 
criterion_1 = ...
criterion_2 = ...
# ...

def train_step(trainer, batch):

    data_1 = batch["data_1"]
    data_2 = batch["data_2"]
    # ...

    model_1.train()
    optimizer_1.zero_grad()
    loss_1 = forward_pass(data_1, model_1, criterion_1)
    loss_1
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值