tensorflow separable_conv2d

import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

# NHWC input: batch 1, 4x4 image, 3 channels, integer values 1..4
input_data = tf.Variable(np.random.randint(1, 5, size=(1, 4, 4, 3)), dtype=np.float32)
# Depthwise filter: 3x3 kernel, 3 input channels, channel multiplier 3
depthwise_filter = tf.Variable(np.random.randint(1, 5, size=(3, 3, 3, 3)), dtype=np.float32)
# Pointwise (1x1) filter: 3 * 3 = 9 depthwise channels -> 3 output channels
# pointwise_filter = tf.Variable(np.random.randint(1, 5, size=(1, 1, 9, 3)), dtype=np.float32)
pointwise_filter = tf.constant(1, shape=(1, 1, 9, 3), dtype=np.float32)

# padding must be the string 'VALID' (or 'SAME'), not a bare name
sepa_out_img = tf.nn.separable_conv2d(input_data, depthwise_filter, pointwise_filter,
                                      strides=[1, 1, 1, 1], padding='VALID')
depth_out_img = tf.nn.depthwise_conv2d(input_data, depthwise_filter,
                                       strides=[1, 1, 1, 1], rate=[1, 1], padding='VALID')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("input_data=", sess.run(input_data))
    # print("depthwise_filter=", sess.run(depthwise_filter))
    print("pointwise_filter=", sess.run(pointwise_filter))
    print("sepa_out_img=", sess.run(sepa_out_img))    # shape (1, 2, 2, 3)
    print("depth_out_img=", sess.run(depth_out_img))  # shape (1, 2, 2, 9)
    print("shape=", sess.run(tf.shape(sepa_out_img)))  # [1 2 2 3]
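The snippet compares tf.nn.separable_conv2d with tf.nn.depthwise_conv2d. To make explicit what the two ops compute, here is a minimal NumPy sketch, assuming stride 1, 'VALID' padding, NHWC layout and no dilation. The helpers depthwise_conv2d_valid, pointwise_conv2d and separable_conv2d_valid are hypothetical names for illustration, not TensorFlow APIs, and this is not TensorFlow's implementation. It follows the channel ordering described in the TensorFlow docs, where output channel c * channel_multiplier + m comes from input channel c and filter slice [:, :, c, m].

```python
import numpy as np


def depthwise_conv2d_valid(x, depthwise_filter):
    """Depthwise convolution sketch: stride 1, 'VALID' padding, NHWC layout.

    x:                (batch, H, W, in_channels)
    depthwise_filter: (fh, fw, in_channels, channel_multiplier)
    Returns (batch, H - fh + 1, W - fw + 1, in_channels * channel_multiplier).
    Output channel c * channel_multiplier + m is input channel c
    cross-correlated with depthwise_filter[:, :, c, m]; channels are never mixed.
    """
    batch, h, w, cin = x.shape
    fh, fw, fcin, mult = depthwise_filter.shape
    assert cin == fcin, "filter in_channels must match input channels"
    oh, ow = h - fh + 1, w - fw + 1
    out = np.zeros((batch, oh, ow, cin * mult), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            patch = x[:, i:i + fh, j:j + fw, :]  # (batch, fh, fw, cin)
            # Sum over the window only, keeping each input channel separate.
            acc = np.einsum("bhwc,hwcm->bcm", patch, depthwise_filter)
            out[:, i, j, :] = acc.reshape(batch, cin * mult)
    return out


def pointwise_conv2d(x, pointwise_filter):
    """1x1 convolution that mixes channels; filter shape (1, 1, C, out_channels)."""
    return np.einsum("bhwc,co->bhwo", x, pointwise_filter[0, 0])


def separable_conv2d_valid(x, depthwise_filter, pointwise_filter):
    """Separable convolution sketch: depthwise step, then a 1x1 pointwise step."""
    return pointwise_conv2d(depthwise_conv2d_valid(x, depthwise_filter),
                            pointwise_filter)


# Same shapes as the TensorFlow snippet above (random values, fixed seed).
rng = np.random.default_rng(0)
input_data = rng.integers(1, 5, size=(1, 4, 4, 3)).astype(np.float32)
depthwise_filter = rng.integers(1, 5, size=(3, 3, 3, 3)).astype(np.float32)
pointwise_filter = np.ones((1, 1, 9, 3), dtype=np.float32)

depth_out = depthwise_conv2d_valid(input_data, depthwise_filter)
sepa_out = separable_conv2d_valid(input_data, depthwise_filter, pointwise_filter)
print(depth_out.shape)  # (1, 2, 2, 9)
print(sepa_out.shape)   # (1, 2, 2, 3)
# With an all-ones pointwise filter, each output channel is the sum
# of the 9 depthwise channels.
print(np.allclose(sepa_out[..., 0], depth_out.sum(axis=-1)))  # True
```

Because pointwise_filter in the TensorFlow snippet is all ones, the three channels of sepa_out_img should be identical, each equal to the sum of the nine channels of depth_out_img at the same position, which makes the printed output easy to check by eye.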
