Tensorflow字符处理

Date: 2019/07/23 Categories: 工作 Tags: Tensorflow



需要给模型加入字符级的嵌入, 但不想输入的特征有重复, 比如原先的输入是

{
    "word": [
        ["中国", "第一", "高峰"],
        ["世界", "最", "高", "山峰"]
    ]
}

现在加入字符输入

{
    "word": [
        ["中国", "第一", "高峰"],
        ["世界", "最", "高", "山峰"]
    ],
    "char": [
        [
            ["中", "国"],
            ["第", "一"],
            ["高", "峰"]
        ],
        [
            ["世", "界"],
            ["最"],
            ["高"],
            ["山", "峰"]
        ]
    ]
}

但不想加入char输入, 因此需要将word分割开来:

Attempt 1: 使用strings.split

import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.keras.layers as L
from tensorflow.python import autograph

# Sample batch of sentences used to exercise the split functions below
# (consumed by the map_fn demo further down).
lines = tf.constant([
    'Hello world',
    'my name is also Mark',
    'Are there any other Marks here ?'
])

def split(i, PAD=10):
    """Split each string in the 1-D tensor `i` into single characters.

    Returns a dense [len(i), PAD] string tensor: each row holds the first
    PAD "characters" of the corresponding input string, right-padded
    with ''.

    NOTE(review): with delimiter='' tf.string_split operates on raw
    bytes, so a multi-byte UTF-8 character (e.g. '中') is broken into
    its individual bytes — this is exactly the flaw Attempt 1
    demonstrates.
    """
    sparse_chars = tf.string_split(i, delimiter='')
    dense_chars = tf.sparse.to_dense(sparse_chars, '')
    clipped = dense_chars[:, :PAD]
    # Pad the character axis on the right so every row is exactly PAD wide.
    pad_amount = PAD - tf.shape(clipped)[1]
    return tf.pad(
        clipped,
        [[0, 0], [0, pad_amount]],
        'CONSTANT',
        constant_values='',
    )

K.get_session().run(split(tf.constant(['中国']))) # array([['\xe4', '\xb8', '\xad', '\xe5', '\x9b', '\xbd', '', '', '', '']], dtype=object)

def split_unistr(s, PAD=10):
    """Split a scalar string `s` into Unicode characters.

    Returns a 1-D string tensor of shape [PAD]: the first PAD characters
    of `s`, right-padded with '' when the string is shorter. Unlike the
    byte-level `split` above, multi-byte UTF-8 characters stay intact.
    """
    # unicode_split performs decode + per-code-point re-encode in one op,
    # replacing the original unicode_decode + map_fn(unicode_encode) pair
    # (map_fn ran one graph-op invocation per character — needlessly slow).
    chars = tf.strings.unicode_split(s, 'UTF-8')
    truncated = chars[:PAD]
    # Pad the 1-D tensor directly on the right; no expand_dims/squeeze
    # rank juggling is needed.
    result = tf.pad(
        truncated,
        [[0, PAD - tf.shape(truncated)[0]]],
        'CONSTANT',
        constant_values='',
    )
    # reshape pins the static shape to [PAD] for downstream layers.
    return tf.reshape(result, [PAD])

# Demo: character-split one long scalar string (truncated to PAD=10),
# then map the splitter over the whole batch of sentences.
r = split_unistr(tf.constant('中国美国xxxxxxxxxxxxxxxxxxx'))
# Fix: `print r` is Python-2-only statement syntax; print() works on both 2 and 3.
print(r)
K.get_session().run(r)

K.get_session().run(tf.map_fn(split_unistr, lines))

# A padded 2-D batch: rows are sentences, columns are word slots
# ('' pads the short first row).
batch_lines = tf.constant([
    ['中国', '美国', ''],
    ['世界',  '最高', '山峰'],
    ['Are', 'there', 'any']
])

# `split` already handles a 1-D tensor, so a single map_fn over the rows
# yields a [rows, words, PAD] result — byte-level, see output below.
r = tf.map_fn(split, batch_lines)
K.get_session().run(r)

# array([[['\xe4', '\xb8', '\xad', '\xe5', '\x9b', '\xbd', '', '', '', ''],
#        ['\xe7', '\xbe', '\x8e', '\xe5', '\x9b', '\xbd', '', '', '', ''],
#        ['', '', '', '', '', '', '', '', '', '']],
#
#       [['\xe4', '\xb8', '\x96', '\xe7', '\x95', '\x8c', '', '', '', ''],
#        ['\xe6', '\x9c', '\x80', '\xe9', '\xab', '\x98', '', '', '', ''],
#        ['\xe5', '\xb1', '\xb1', '\xe5', '\xb3', '\xb0', '', '', '', '']],
#
#       [['A', 'r', 'e', '', '', '', '', '', '', ''],
#        ['t', 'h', 'e', 'r', 'e', '', '', '', '', ''],
#        ['a', 'n', 'y', '', '', '', '', '', '', '']]], dtype=object)

# `split_unistr` takes a scalar string, so nest map_fn: the outer map walks
# rows, the inner map walks the words of each row. Output keeps whole
# UTF-8 characters (see below), unlike the byte-level attempt.
r = tf.map_fn(lambda x: tf.map_fn(split_unistr, x), batch_lines)
K.get_session().run(r)

#array([[['\xe4\xb8\xad', '\xe5\x9b\xbd', '', '', '', '', '', '', '', ''],
#        ['\xe7\xbe\x8e', '\xe5\x9b\xbd', '', '', '', '', '', '', '', ''],
#        ['', '', '', '', '', '', '', '', '', '']],
#
#       [['\xe4\xb8\x96', '\xe7\x95\x8c', '', '', '', '', '', '', '', ''],
#        ['\xe6\x9c\x80', '\xe9\xab\x98', '', '', '', '', '', '', '', ''],
#        ['\xe5\xb1\xb1', '\xe5\xb3\xb0', '', '', '', '', '', '', '', '']],
#
#       [['A', 'r', 'e', '', '', '', '', '', '', ''],
#        ['t', 'h', 'e', 'r', 'e', '', '', '', '', ''],
#        ['a', 'n', 'y', '', '', '', '', '', '', '']]], dtype=object)