PyTorch是很受欢迎的机器学习库,对应库名为torch;
'''使用torch识别简单正态曲线'''
import torch
import numpy as np
from scipy.stats import norm
def create_norm_zero(x_num):
x_base = np.linspace(-6, 6, 100)*x_num
y_base = np.linspace(-3, 3, 100)
y_show = norm.pdf(x_base)/np.max(norm.pdf(y_base))
return [y_show*20]
def create_norm_data():
a = []
norm_data = []
for i in range(1):
num = (i + 1) * 0.23 + 1
norm_data.append(create_norm_zero(num))
a.append(norm_data)
a = np.array(a)
return a
a = torch.ones(1,1,1,100)
b = create_norm_data()
c = torch.from_numpy(b)
x = torch.cat((c, a), 1).type(torch.FloatTensor)
y_t = torch.ones(1,1,50)
y_f = torch.zeros(1,1,50)
y = torch.cat((y_t, y_f), -1).type(torch.LongTensor)
上段为生成训练数据部分,下段为定义模型、训练模型部分;