Tensorflow字符处理
Date: 2019/07/23 Categories: 工作 Tags: Tensorflow
需要给模型加入字符级的嵌入, 但不想输入的特征有重复, 比如原先的输入是
[
"word": [
["中国", "第一", "高峰"],
["世界", "最", "高", "山峰"]
]
]
现在加入字符输入
[
"word": [
["中国", "第一", "高峰"],
["世界", "最", "高", "山峰"]
],
"char": [
[
["中", "国"],
["第", "一"],
["高", "峰"],
],
[
["世", "界"],
["最"],
["高"],
["山", "峰"]
]
]
]
但不想额外维护一份 char 输入(避免特征重复), 因此需要在模型内部将 word 切分成字符:
Attempt 1: 使用 tf.string_split
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.keras.layers as L
from tensorflow.python import autograph
# Sample batch of sentences used to exercise the split helpers below.
lines = tf.constant([
'Hello world',
'my name is also Mark',
'Are there any other Marks here ?'
])
def split(i, PAD=10):
    """Split each string in a 1-D string tensor into characters, padded to PAD.

    NOTE: tf.string_split with delimiter='' splits on *bytes*, so multi-byte
    UTF-8 characters (e.g. Chinese) come out as individual bytes — this is
    the shortcoming the demo below illustrates.
    """
    sparse_chars = tf.string_split(i, delimiter='')
    dense_chars = tf.sparse.to_dense(sparse_chars, '')
    # Clip to at most PAD columns, then right-pad with '' up to exactly PAD.
    clipped = dense_chars[:, :PAD]
    pad_spec = [[0, 0], [0, PAD - tf.shape(clipped)[1]]]
    return tf.pad(clipped, pad_spec, 'CONSTANT', constant_values='')
# Demo: byte-level splitting breaks '中国' into its six UTF-8 bytes, not two chars.
K.get_session().run(split(tf.constant(['中国']))) # array([['\xe4', '\xb8', '\xad', '\xe5', '\x9b', '\xbd', '', '', '', '']], dtype=object)
def split_unistr(s, PAD=10):
    """Split a scalar string into unicode characters, padded to length PAD.

    Decodes to codepoints first so multi-byte UTF-8 characters stay intact,
    then re-encodes each codepoint back into a one-character string.
    Returns a 1-D string tensor of shape [PAD].
    """
    codepoints = tf.strings.unicode_decode(s, 'UTF-8')
    # Shape [n] -> [n, 1] so map_fn re-encodes one codepoint per step.
    codepoints = tf.expand_dims(codepoints, 1)
    chars = tf.map_fn(
        lambda cp: tf.strings.unicode_encode(cp, 'UTF-8'),
        codepoints,
        dtype=tf.string,
    )
    # Truncate, then right-pad with '' to a fixed width of PAD.
    chars = chars[:PAD]
    row = tf.expand_dims(chars, 0)
    pad_spec = [[0, 0], [0, PAD - tf.shape(row)[1]]]
    row = tf.pad(row, pad_spec, 'CONSTANT', constant_values='')
    row = tf.squeeze(row)
    return tf.reshape(row, [PAD])
# Single-string demo: the 23-char input is truncated to PAD=10 characters.
r = split_unistr(tf.constant('中国美国xxxxxxxxxxxxxxxxxxx'))
# print() with one argument behaves the same under Python 2 and Python 3,
# unlike the Python-2-only `print r` statement form.
print(r)
K.get_session().run(r)
# Apply the character split to every sentence in the batch independently.
K.get_session().run(tf.map_fn(split_unistr, lines))
# Padded 2-D word batch: every row already has the same number of word slots,
# with '' filling the short rows.
batch_lines = tf.constant([
['中国', '美国', ''],
['世界', '最高', '山峰'],
['Are', 'there', 'any']
])
# Byte-level split per word: each multi-byte UTF-8 char explodes into raw bytes.
r = tf.map_fn(split, batch_lines)
K.get_session().run(r)
# array([[['\xe4', '\xb8', '\xad', '\xe5', '\x9b', '\xbd', '', '', '', ''],
# ['\xe7', '\xbe', '\x8e', '\xe5', '\x9b', '\xbd', '', '', '', ''],
# ['', '', '', '', '', '', '', '', '', '']],
#
# [['\xe4', '\xb8', '\x96', '\xe7', '\x95', '\x8c', '', '', '', ''],
# ['\xe6', '\x9c', '\x80', '\xe9', '\xab', '\x98', '', '', '', ''],
# ['\xe5', '\xb1', '\xb1', '\xe5', '\xb3', '\xb0', '', '', '', '']],
#
# [['A', 'r', 'e', '', '', '', '', '', '', ''],
# ['t', 'h', 'e', 'r', 'e', '', '', '', '', ''],
# ['a', 'n', 'y', '', '', '', '', '', '', '']]], dtype=object)
# Character-level split via nested map_fn: unicode characters stay intact
# (each cell below is a whole UTF-8 character, e.g. '\xe4\xb8\xad' == '中').
r = tf.map_fn(lambda x: tf.map_fn(split_unistr, x), batch_lines)
K.get_session().run(r)
#array([[['\xe4\xb8\xad', '\xe5\x9b\xbd', '', '', '', '', '', '', '', ''],
# ['\xe7\xbe\x8e', '\xe5\x9b\xbd', '', '', '', '', '', '', '', ''],
# ['', '', '', '', '', '', '', '', '', '']],
#
# [['\xe4\xb8\x96', '\xe7\x95\x8c', '', '', '', '', '', '', '', ''],
# ['\xe6\x9c\x80', '\xe9\xab\x98', '', '', '', '', '', '', '', ''],
# ['\xe5\xb1\xb1', '\xe5\xb3\xb0', '', '', '', '', '', '', '', '']],
#
# [['A', 'r', 'e', '', '', '', '', '', '', ''],
# ['t', 'h', 'e', 'r', 'e', '', '', '', '', ''],
# ['a', 'n', 'y', '', '', '', '', '', '', '']]], dtype=object)