-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: MNIST.cs
More file actions
94 lines (73 loc) · 2.45 KB
/
MNIST.cs
File metadata and controls
94 lines (73 loc) · 2.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#define MNIST
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using TorchSharpDataLoaderExample;
using TorchSharp;
using static System.Linq.Enumerable;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
#if MNIST
// Top-level training script: trains the CNN defined below on MNIST for
// 20 epochs, printing a running batch counter and the per-epoch mean loss.
// NOTE(review): the "dataset" path and the "train"/"t10k" split names are
// consumed by MNISTReader — presumably the standard MNIST file prefixes;
// verify against that class.
using var trainDataset = new MNISTReader("dataset", "train");
using var testDataset = new MNISTReader("dataset", "t10k");
using var train = new DataLoader(trainDataset, 128, true, CPU);  // shuffle training batches
using var test = new DataLoader(testDataset, 256, false, CPU);   // loaded but not yet evaluated here

var model = new Model();
var criterion = functional.cross_entropy_loss();
var optimizer = optim.Adam(model.parameters(), learningRate: 0.01);
Console.WriteLine("Initialized");

// (Removed a dead `foreach (var x in train) { }` loop that enumerated the
// whole training DataLoader once and discarded every batch — it only cost
// a full epoch of data loading before training started.)
foreach (var epoch in Range(1, 20))
{
    var avgCost = 0.0;
    var idx = 0;
    foreach (var d in train)
    {
        // Overwrite the same console line with the current batch index.
        Console.Write($"\r{idx++} / {train.Count}");
        optimizer.zero_grad();
        var hypothesis = model.forward(d["data"]);
        var cost = criterion(hypothesis, d["label"]);
        cost.backward();
        optimizer.step();
        // train.Count is the number of batches, so this accumulates the
        // mean batch loss over the epoch.
        avgCost += cost.ToSingle() / train.Count;
    }
    Console.WriteLine("\r" + avgCost);
}
/// <summary>
/// Small convolutional classifier for MNIST:
/// conv(1→32) → ReLU → conv(32→64) → ReLU → 2×2 max-pool → dropout →
/// flatten → linear(9216→128) → ReLU → dropout → linear(128→10) → log-softmax.
/// The 9216-wide first linear layer implies 28×28 single-channel inputs
/// (24×24×64 after the two 3×3 convs and pooling halves it to 12×12×64)
/// — TODO confirm against MNISTReader's output shape.
/// </summary>
class Model : Module
{
    // Layers with trainable parameters. Fields are readonly: they are set
    // once here and their names are picked up by RegisterComponents().
    private readonly Module conv1 = Conv2d(1, 32, 3);
    private readonly Module conv2 = Conv2d(32, 64, 3);
    private readonly Module fc1 = Linear(9216, 128);
    private readonly Module fc2 = Linear(128, 10);

    // These don't have any parameters, so the only reason to instantiate
    // them is performance, since they will be used over and over.
    private readonly Module pool1 = MaxPool2d(kernelSize: new long[] { 2, 2 });
    private readonly Module relu1 = ReLU();
    private readonly Module relu2 = ReLU();
    private readonly Module relu3 = ReLU();
    private readonly Module dropout1 = Dropout(0.25);
    private readonly Module dropout2 = Dropout(0.5);
    private readonly Module flatten = Flatten();
    private readonly Module logsm = LogSoftmax(1);

    public Model() : base("h")
    {
        // Registers the module fields above (by reflection over field names)
        // so parameters() and state-dict handling see them.
        RegisterComponents();
    }

    /// <summary>Runs the forward pass and returns per-class log-probabilities.</summary>
    /// <param name="input">Batch of images; see class summary for the expected shape.</param>
    public override Tensor forward(Tensor input)
    {
        // NOTE(review): intermediate tensors are not disposed here; they are
        // needed by autograd during backward(), so eager disposal would be
        // unsafe. Memory is reclaimed when the wrappers are finalized.
        var l11 = conv1.forward(input);
        var l12 = relu1.forward(l11);

        var l21 = conv2.forward(l12);
        var l22 = relu2.forward(l21);
        var l23 = pool1.forward(l22);
        var l24 = dropout1.forward(l23);

        var x = flatten.forward(l24);

        var l31 = fc1.forward(x);
        var l32 = relu3.forward(l31);
        var l33 = dropout2.forward(l32);

        var l41 = fc2.forward(l33);
        return logsm.forward(l41);
    }
}
#endif