维比特解码-动态规划


  • 直接上代码
import numpy as np

# 转移概率,单纯用了等概率 >> b分词开始, e分词结束, m分词的中间,s仅仅单字分开
zy = {'be': 0.5,  # be 开始可以到结束
      'bm': 0.5,  # bm 开始到中间
      'eb': 0.5,  # em 结束可以到开始
      'es': 0.5,  # es 结束可以到仅仅单字
      'me': 0.5,  # me 中间可以到结束
      'mm': 0.5,  # mm 中间可以到中间
      'sb': 0.5,  # sb 仅仅单字到开始
      'ss': 0.5   # ss 仅仅单字到单字
      }
zy = {i: np.log(zy[i]) for i in zy.keys()}
def viterbi(nodes):
    paths = {'b': nodes[0]['b'], 's': nodes[0]['s']}
    for l in range(1, len(nodes)):
        paths_ = paths.copy()
        paths = {}
        for i in nodes[l].keys():
            nows = {}
            for j in paths_.keys():
                if j[-1] + i in zy.keys():
                    nows[j + i] = paths_[j] + nodes[l][i] + zy[j[-1] + i]
            k = np.argmax(list(nows.values()))
            paths[list(nows.keys())[k]] = list(nows.values())[k]
    return list(paths.keys())[np.argmax(list(paths.values()))]
''' # 搜索路径
{'b': -0.015575318, 's': -6.057363} # 开始只能是b和s开头
{'ss': -12.232245180559945, 'sb': -10.078559880559945, 'bm': -6.001635898559945, 'be': -0.7548268105599453} # 只能由b,s转移过来
{'bes': -1.8031197911198906, 'beb': -3.1734596911198905, 'bmm': -13.754842079119891, 'bme': -8.81519967911989} # 
{'bess': -2.793391871679836, 'besb': -3.930443871679836, 'bebm': -10.070449871679836, 'bebe': -7.958071871679836}
{'besss': -10.28113575223978, 'bessb': -3.5089358942397815, 'besbm': -9.212326652239781, 'besbe': -9.146066052239782}
{'besbes': -15.647376232799727, 'besbeb': -12.655429032799727, 'bessbm': -7.489156074799727, 'bessbe': -4.307664504799726}
{'bessbes': -5.027113643359672, 'bessbeb': -13.851752685359672, 'bessbmm': -18.539550255359668, 'bessbme': -11.840306155359672}
{'bessbess': -6.4368791839196176, 'bessbesb': -6.391419973919618, 'bessbebm': -22.522033665919615, 'bessbebe': -23.422544865919615}
{'bessbesss': -7.691225734479563, 'bessbessb': -13.506323364479563, 'bessbesbm': -13.615205854479564, 'bessbesbe': -7.937145494479563}
{'bessbessss': -9.722234715039509, 'bessbesssb': -8.695739685039507, 'bessbessbm': -19.816604145039506, 'bessbessbe': -20.702765945039506}
{'bessbesssss': -15.996908695599453, 'bessbessssb': -11.647404195599455, 'bessbesssbm': -9.937732865599452, 'bessbesssbe': -11.454922365599453} # 取最大的分数对应的路径
'''
s = '人们常说生活是一部教科书'  # "分词的文本"
# lstm >> Dense 预测得到的每个字的label的分数
nodes = [{'s': -6.057363, 'b': -0.015575318, 'm': -5.010491, 'e': -5.04824},
         {'s': -5.481735, 'b': -3.3280497, 'm': -5.2929134, 'e': -0.046104312},
         {'s': -0.3551458, 'b': -1.7254857, 'm': -7.060059, 'e': -2.1204166},
         {'s': -0.2971249, 'b': -1.4341769, 'm': -6.203843, 'e': -4.091465},
         {'s': -6.7945967, 'b': -0.022396842, 'm': -4.5887356, 'e': -4.522475},
         {'s': -5.808163, 'b': -2.8162158, 'm': -3.287073, 'e': -0.10558143},
         {'s': -0.026301958, 'b': -8.850941, 'm': -10.357247, 'e': -3.6580029},
         {'s': -0.71661836, 'b': -0.67115915, 'm': -7.9771338, 'e': -8.877645},
         {'s': -0.56119937, 'b': -6.376297, 'm': -6.5306387, 'e': -0.85257834},
         {'s': -1.3378618, 'b': -0.31136677, 'm': -5.6171336, 'e': -6.5032954},
         {'s': -5.5815268, 'b': -1.2320223, 'm': -0.548846, 'e': -2.0660355},
         {'s': -4.801048, 'b': -6.7394543, 'm': -2.2255623, 'e': -0.12511909}]
# 解码
t = viterbi(nodes)
print(t)  # bessbesssbme
# 规则进行分词得到结果
words = []
for i in range(len(s)):
    if t[i] in ['s', 'b']:
        words.append(s[i])
    else:
        words[-1] += s[i]
print(words)  # ['人们', '常', '说', '生活', '是', '一', '部', '教科书']