forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path VariableTest.cs
More file actions
153 lines (134 loc) · 4.47 KB
/
VariableTest.cs
File metadata and controls
153 lines (134 loc) · 4.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using static Tensorflow.Python;
namespace TensorFlowNET.UnitTest
{
    /// <summary>
    /// Unit tests covering creation, initialization, variable scoping and assignment
    /// of TensorFlow.NET variables.
    /// </summary>
    [TestClass]
    public class VariableTest
    {
        /// <summary>
        /// Running a variable's own initializer makes its initial value readable via the session.
        /// </summary>
        [TestMethod]
        public void Initializer()
        {
            var x = tf.Variable(10, name: "x");
            using (var session = tf.Session())
            {
                session.run(x.initializer);
                var result = session.run(x);
                Assert.AreEqual(10, (int)result);
            }
        }

        /// <summary>
        /// Smoke test: string variables can be constructed both with an explicit name/dtype
        /// and with defaults. No value is checked; the test fails only if construction throws.
        /// </summary>
        [TestMethod]
        public void StringVar()
        {
            var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.chars);
            var mammal2 = tf.Variable("Tiger");
        }

        /// <summary>
        /// https://www.tensorflow.org/api_docs/python/tf/variable_scope
        /// Creating a variable inside nested variable scopes prefixes its name with each scope.
        /// </summary>
        [TestMethod]
        public void VarCreation()
        {
            tf.Graph().as_default();
            with(tf.variable_scope("foo"), delegate
            {
                with(tf.variable_scope("bar"), delegate
                {
                    var v = tf.get_variable("v", new TensorShape(1));
                    // MSTest convention: expected value first, actual second.
                    Assert.AreEqual("foo/bar/v:0", v.name);
                });
            });
        }

        /// <summary>
        /// Re-enters a previously created variable scope without creating an auxiliary
        /// name scope. It then restores the scope's original name scope so that both
        /// variables and ops land under "foo/".
        /// </summary>
        [TestMethod]
        public void ReenterVariableScope()
        {
            tf.Graph().as_default();
            variable_scope vs = null;
            // Capture the scope object so it can be re-entered below.
            with(tf.variable_scope("foo"), v => vs = v);

            // Re-enter the variable scope.
            with(tf.variable_scope(vs, auxiliary_name_scope: false), v =>
            {
                var vs1 = (VariableScope)v;
                // Restore the original name_scope.
                with(tf.name_scope(vs1.original_name_scope), delegate
                {
                    var v1 = tf.get_variable("v", new TensorShape(1));
                    Assert.AreEqual("foo/v:0", v1.name);
                    var c1 = tf.constant(new int[] { 1 }, name: "c");
                    Assert.AreEqual("foo/c:0", c1.name);
                });
            });
        }

        /// <summary>
        /// A variable initialized from an expression (constant 3 + 1) evaluates to 4
        /// after the global initializer runs.
        /// </summary>
        [TestMethod]
        public void ScalarVar()
        {
            var x = tf.constant(3, name: "x");
            var y = tf.Variable(x + 1, name: "y");
            var model = tf.global_variables_initializer();
            using (var session = tf.Session())
            {
                session.run(model);
                int result = session.run(y);
                Assert.AreEqual(4, result);
            }
        }

        /// <summary>
        /// Running an assign op overwrites the variable's value; the op itself yields the
        /// newly assigned value.
        /// </summary>
        [TestMethod]
        public void Assign1()
        {
            with(tf.Graph().as_default(), graph =>
            {
                var variable = tf.Variable(31, name: "tree");
                var init = tf.global_variables_initializer();
                // Dispose the session deterministically; it was previously leaked.
                using (var sess = tf.Session(graph))
                {
                    sess.run(init);
                    var result = sess.run(variable);
                    Assert.AreEqual(31, (int)result);

                    var assign = variable.assign(12);
                    result = sess.run(assign);
                    Assert.AreEqual(12, (int)result);
                }
            });
        }

        /// <summary>
        /// Builds an increment op (v1 = v1 + 1) and runs it once. There is no assertion;
        /// the test fails only if execution throws.
        /// </summary>
        [TestMethod]
        public void Assign2()
        {
            var v1 = tf.Variable(10.0f, name: "v1"); //tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);
            var inc_v1 = v1.assign(v1 + 1.0f);

            // Add an op to initialize the variables.
            var init_op = tf.global_variables_initializer();

            with(tf.Session(), sess =>
            {
                sess.run(init_op);
                // Do some work with the model.
                // NOTE(review): op.run() takes no session argument; it presumably uses the
                // session installed by `with` as the default — confirm.
                inc_v1.op.run();
            });
        }

        /// <summary>
        /// https://databricks.com/tensorflow/variables
        /// Each iteration rebinds x to a new graph node (previous x + 1) and evaluates it.
        /// After five iterations starting from 10 the result is 15.
        /// </summary>
        [TestMethod]
        public void Add()
        {
            int result = 0;
            Tensor x = tf.Variable(10, name: "x");
            var init_op = tf.global_variables_initializer();
            using (var session = tf.Session())
            {
                session.run(init_op);
                for (int i = 0; i < 5; i++)
                {
                    x = x + 1;
                    result = session.run(x);
                    print(result);
                }
            }

            Assert.AreEqual(15, result);
        }
    }
}