forked from princewen/tensorflow_practice
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path embedding_basic.py
More file actions
112 lines (80 loc) · 2.82 KB
/
embedding_basic.py
File metadata and controls
112 lines (80 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import tensorflow as tf
# embedding: selecting rows of a constant table two ways.
#   1) tf.nn.embedding_lookup gathers the rows named by the ids directly.
#   2) one-hot encoding the ids and matrix-multiplying into the table picks
#      the same rows, so both printed matrices should match.
embedding = tf.constant(
    [
        [0.21, 0.41, 0.51, 0.11],
        [0.22, 0.42, 0.52, 0.12],
        [0.23, 0.43, 0.53, 0.13],
        [0.24, 0.44, 0.54, 0.14],
    ],
    dtype=tf.float32,
)
feature_batch = tf.constant([2, 3, 1, 0])

# Route 1: direct lookup by id.
get_embedding1 = tf.nn.embedding_lookup(embedding, feature_batch)

# Route 2: one-hot rows (depth 4 = number of table rows) times the table.
feature_batch_one_hot = tf.one_hot(feature_batch, depth=4)
get_embedding2 = tf.matmul(feature_batch_one_hot, embedding)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    embedding1, embedding2 = sess.run([get_embedding1, get_embedding2])
    for result in (embedding1, embedding2):
        print(result)
    # Also show the one-hot matrix that route 2 multiplied by.
    print(sess.run(feature_batch_one_hot))
# embedding1: same lookup-vs-one-hot comparison as above, but the table is a
# variable created with tf.get_variable and filled from a uniform initializer
# instead of fixed constants.
embedding = tf.get_variable(
    name='embedding',
    shape=[4, 4],
    dtype=tf.float32,
    # Pass an initializer *instance*. The original passed the bare class,
    # which only works because get_variable may instantiate a class
    # implicitly; the explicit instance (default minval/maxval) is the
    # documented form and does not depend on that.
    initializer=tf.random_uniform_initializer(),
)
feature_batch = tf.constant([2, 3, 1, 0])
get_embedding1 = tf.nn.embedding_lookup(embedding, feature_batch)
feature_batch_one_hot = tf.one_hot(feature_batch, depth=4)
get_embedding2 = tf.matmul(feature_batch_one_hot, embedding)
with tf.Session() as sess:
    # The variable must be initialised before it can be read.
    sess.run(tf.global_variables_initializer())
    # Both tensors are fetched in one run, so they see the same random table.
    embedding1, embedding2 = sess.run([get_embedding1, get_embedding2])
    print(embedding1)
    print(embedding2)
# Single-dimension indexing with tf.gather: select whole slices along one axis.
embedding = tf.constant(
    [
        [0.21, 0.41, 0.51, 0.11],
        [0.22, 0.42, 0.52, 0.12],
        [0.23, 0.43, 0.53, 0.13],
        [0.24, 0.44, 0.54, 0.14],
    ],
    dtype=tf.float32,
)
index_a = tf.Variable([2, 3, 1, 0])
# Default axis (0) reorders rows; axis=1 reorders columns with the same ids.
gather_a = tf.gather(embedding, index_a)
gather_a_axis1 = tf.gather(embedding, index_a, axis=1)

# On a rank-1 tensor, gather just picks elements by position.
b = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
index_b = tf.Variable([2, 4, 6, 8])
gather_b = tf.gather(b, index_b)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Print order: row gather, 1-D gather, column gather.
    for gathered in (gather_a, gather_b, gather_a_axis1):
        print(sess.run(gathered))
# Multi-dimensional indexing with tf.gather_nd: each innermost index vector
# addresses one slice (or one scalar) of the params tensor.
a = tf.Variable([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])
# A length-1 index into the rank-2 `a` selects a whole row: a[2].
index_a = tf.Variable([2])
# Use an initializer *instance*; passing the bare class relies on get_variable
# instantiating it implicitly.
b = tf.get_variable(
    name='b',
    shape=[3, 3, 2],
    initializer=tf.random_normal_initializer(),
)
# Two length-3 indices into the rank-3 `b`: selects the scalars
# b[0, 1, 1] and b[2, 2, 0].
index_b = tf.Variable([[0, 1, 1], [2, 2, 0]])
# Build the gather ops with the rest of the graph rather than inside the
# session block.
gather_nd_a = tf.gather_nd(a, index_a)
gather_nd_b = tf.gather_nd(b, index_b)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gather_nd_a))
    # Print the full random tensor so the gathered scalars can be checked.
    print(sess.run(b))
    print(sess.run(gather_nd_b))
# sparse embedding: combine the table rows referenced by each row of a
# SparseTensor of ids.
# The lookup groups ids by row (the first index coordinate) and treats the
# stored values as ids: row 0 holds id 1, row 1 holds ids 2 and 3.
a = tf.SparseTensor(
    indices=[[0, 0], [1, 2], [1, 3]],
    # embedding_lookup_sparse documents sp_ids as an int64 SparseTensor.
    values=tf.constant([1, 2, 3], dtype=tf.int64),
    dense_shape=[2, 4],
)
b = tf.sparse_tensor_to_dense(a)
embedding = tf.constant(
    [
        [0.21, 0.41, 0.51, 0.11],
        [0.22, 0.42, 0.52, 0.12],
        [0.23, 0.43, 0.53, 0.13],
        [0.24, 0.44, 0.54, 0.14],
    ],
    dtype=tf.float32,
)
# With sp_weights=None every id has weight 1. combiner='mean' is spelled out
# because leaving it as None makes TF 1.x fall back to 'mean' with a logged
# warning about the default.
# Expected result: row 0 -> embedding[1]; row 1 -> mean(embedding[2], embedding[3]).
embedding_sparse = tf.nn.embedding_lookup_sparse(
    embedding, sp_ids=a, sp_weights=None, combiner='mean')
# No variables are created in this section, so no initializer run is needed.
with tf.Session() as sess:
    print(sess.run(embedding_sparse))
    print(sess.run(b))
    # NOTE: this literal is not computed by any op in this section. Its values
    # equal tf.gather(embedding, [1, 0], axis=1), i.e. columns 1 and 0 of the
    # table. It is kept verbatim from the original script.
    print("""
[[0.41,0.21],
[0.42,0.22],
[0.43,0.23],
[0.44,0.24]]
""")