博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tf.concat&tf.gather&tf.gather_nd&tf.greater&tf.cast&tf.expand_dims&tf.squeeze
阅读量:5898 次
发布时间:2019-06-19

本文共 2387 字,大约阅读时间需要 7 分钟。

Tensorflow常用函数笔记


把一组向量从某一维上拼接起来,很向numpy中的Concatenate,官网例子:

t1 = [[1, 2, 3], [4, 5, 6]]t2 = [[7, 8, 9], [10, 11, 12]]tf.concat([t1, t2], 0) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]tf.concat([t1, t2], 1) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]# tensor t3 with shape [2, 3]# tensor t4 with shape [2, 3]tf.shape(tf.concat([t3, t4], 0)) ==> [4, 3]tf.shape(tf.concat([t3, t4], 1)) ==> [2, 6]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

其实,如果是list类型的话也是可以的,只要是形似Tensor,最后tf.concat返回的还是Tensor类型

类似于数组的索引,可以把向量中某些索引值提取出来,得到新的向量,适用于要提取的索引为不连续的情况。这个函数似乎只适合在一维的情况下使用。

import tensorflow as tf a = tf.Variable([[1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15]])index_a = tf.Variable([0,2])b = tf.Variable([1,2,3,4,5,6,7,8,9,10])index_b = tf.Variable([2,4,6,8])with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    print(sess.run(tf.gather(a, index_a)))    print(sess.run(tf.gather(b, index_b)))#  [[ 1  2  3  4  5]#   [11 12 13 14 15]]#  [3 5 7 9]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

同上,但允许在多维上进行索引,例子只展示了一种很简单的用法,更复杂的用法可见官网。

import tensorflow as tf a = tf.Variable([[1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15]])index_a = tf.Variable([[0,2], [0,4], [2,2]])with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    print(sess.run(tf.gather_nd(a, index_a)))#  [ 3  5 13]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

判断函数。首先张量x和张量y的尺寸要相同,输出的tf.greater(x, y)也是一个和x,y尺寸相同的张量。如果x的某个元素比y中对应位置的元素大,则tf.greater(x, y)对应位置返回True,否则返回False。与此类似的函数还有。

import tensorflow as tf x = tf.Variable([[1,2,3], [6,7,8], [11,12,13]])y = tf.Variable([[0,1,2], [5,6,7], [10,11,12]])x1 = tf.Variable([[1,2,3], [6,7,8], [11,12,13]])y1 = tf.Variable([[10,1,2], [15,6,7], [10,21,12]])with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    print(sess.run(tf.greater(x, y)))    print(sess.run(tf.greater(x1, y1)))#  [[ True  True  True]#   [ True  True  True]#   [ True  True  True]]#  [[False  True  True]#   [False  True  True]#   [ True False  True]]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

转换数据类型。

a = tf.constant([0, 2, 0, 4, 2, 2], dtype='int32')print(a)# 
b = tf.cast(a, 'float32')print(b)#
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

&

增加 / 压缩张量的维度。

a = tf.constant([0, 2, 0, 4, 2, 2], dtype='int32')print(a)# 
b = tf.expand_dims(a, 0)print(b)#
print(tf.squeeze(b, 0))#
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

转载于:https://www.cnblogs.com/leebxo/p/10457441.html

你可能感兴趣的文章
Web Service简介特点,优点,缺点
查看>>
MessageContext和传输头之续一(实例演示:SIB中访问消息上下文)
查看>>
#10# SCCM规划 - 边界、边界组和站点系统 - 3
查看>>
Office365跨订阅迁移邮箱-批量导入用户PST文件
查看>>
关于Objective-c内存管理的一些笔记
查看>>
用Allegro对s3c2410的BGA封装布线
查看>>
江苏电信:SOC建设介绍
查看>>
android 应用开发:android studio使用笔记
查看>>
Centos7 YUM安装MariaDB 10.0
查看>>
Windows 2008 R2 远程桌面服务(八)远程桌面服务器安全设置
查看>>
Spring data jpa模糊查询,根据某一个字段,或者多个字段进行模糊查询
查看>>
制作QQ2011绿色版不求人
查看>>
c++filt 命令
查看>>
android控件EditText
查看>>
shell编程学习之tr
查看>>
maven多web合并项目
查看>>
Wget用法、参数解释的比较好的一个文章
查看>>
activiti学习笔记01_20130909
查看>>
学习者来报道
查看>>
学习:java设计模式—工厂模式
查看>>