import torch
import torch.nn as nn
from torch.nn import functional as F
# 基本卷积块, 长宽不变,in_channels -> out_channels
class Conv(nn.Module):
def __init__(self, C_in, C_out):
super(Conv, self).__init__()
self.layer = nn.Sequential(
nn.Conv2d(C_in, C_out, 3, 1, 1),
nn.BatchNorm2d(C_out),
# 防止过拟合
nn.Dropout(0.3),
nn.LeakyReLU(),
nn.Conv2d(C_out, C_out, 3, 1, 1),
nn.BatchNorm2d(C_out),
# 防止过拟合
nn.Dropout(0.4),
nn.LeakyReLU(),
)
def forward(self, x):
return self.layer(x)
# 下采样模块,长宽下采样2倍,通道数不变
class DownSampling(nn.Module):
def __init__(self, C):
super(DownSampling, self).__init__()
self.Down = nn.Sequential(
# 使用卷积进行2倍的下采样,通道数不变
nn.Conv2d(C, C, 3, 2, 1),
nn.LeakyReLU()
)
def forward(self, x):
return s
Unet代码实现(PyTorch)
于 2023-04-22 14:31:23 首次发布