tf.nn.embedding_lookup函數原理?

一直沒搞懂這個函數是如何將索引轉化為向量的,求大神指導


embedding_lookup(params, ids)其實就是按照ids順序返回params中的第ids行。

比如說,ids=[1,3,2],就是返回params中第1,3,2行。返回結果為由params的1,3,2行組成的tensor.

最近在看,一起學習。


假設一共有m個物體,每個物體有自己唯一的id,那麼從物體的集合到mathbb R^m有一個trivial的嵌入,就是把它映射到mathbb R^m中的標準基,這種嵌入叫做One-hot embedding/encoding.

應用中一般將物體嵌入到一個低維空間mathbb R^n(n ll m) ,只需要再compose上一個從mathbb R^mmathbb R^n的線性映射就好了。每一個n	imes m 的矩陣M都定義了mathbb R^mmathbb R^n的一個線性映射: x mapsto Mx。當x 是一個標準基向量的時候,Mx對應矩陣M中的一列,這就是對應id的向量表示。這個概念用神經網路圖來表示如下:

從id(索引)找到對應的One-hot encoding,然後紅色的weight就直接對應了輸出節點的值(注意這裡沒有activation function),也就是對應的embedding向量。


求通俗講解下tensorflow的embedding_lookup介面的意思? - 知乎


其實針對輸入是超高維,但是是one hot向量的一種特殊的全連接層的實現方法。由於輸入one hot 的原因,Wx的矩陣乘法看起來就像是取了W中對應的一列,看起來就像是在查表


import tensorflow as tf

import numpy as np

sess=tf.InteractiveSession()

embedding=tf.Variable(np.identity(5,dtype=np.int32))

input_ids=tf.placeholder(dtype=tf.int32,shape=[None])

input_embedding=tf.nn.embedding_lookup(embedding,input_ids)

tf.global_variables_initializer().run()

print (sess.run(embedding))

print (sess.run(input_embedding,feed_dict={input_ids:[1,2,3,0,3,2,1]}))

輸出:

[[1 0 0 0 0]

[0 1 0 0 0]

[0 0 1 0 0]

[0 0 0 1 0]

[0 0 0 0 1]]

[[0 1 0 0 0]

[0 0 1 0 0]

[0 0 0 1 0]

[1 0 0 0 0]

[0 0 0 1 0]

[0 0 1 0 0]

[0 1 0 0 0]]


Here"s a small example to give you a visual.

This means that the hidden layer of this model is really just operating as a lookup table. The output of the hidden layer is just the 「word vector」 for the input word.


  1. 問題本質

只是想做一次常規的線性變換而已,Z = WX + b

2. Embedding

由於輸入都是One-Hot Encoding,和矩陣相乘相當於是取出Weights矩陣中對應的那一行,所以tensoflow封裝了方法 tf.nn.embedding_lookup(params, ids)介面,更加方便的表示意思。查找params對應的ids行。

等於說變相的進行了一次矩陣相乘運算,其實就是一次線性變換。


從這裡看到的What does tf.nn.embedding_lookup function do?

作用和下面類似:

matrix = np.random.random([1024, 64]) # 64-dimensional embeddings
ids = np.array([0, 5, 17, 33])
print matrix[ids] # prints a matrix of shape [4, 64]


推薦閱讀:

深度學習演算法哪些適用於文本處理?
什麼是圖像分類的Top-5錯誤率?
自動控制、機器人、人工智慧等領域有哪些值得引進「影印版」的專業書籍?
聚類和協同過濾是什麼關係?
主題模型(topic model)到底還有沒有用,該怎麼用?

TAG:機器學習 | 深度學習DeepLearning | TensorFlow |