# 模型图 (model diagram):
import numpy as np
import random
import math
import os
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pandas as pd
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
class CNN_LSTM_ATT_DNN_Net(nn.Module):
def __init__(self):
# 模型是cnn + lstm + lstm + Dense
super(CNN_LSTM_ATT_DNN_Net, self).__init__()
# 初始参数-------
self.input_size=31
# LSTM
self.ce