背景
文本数据的大型预训练语言模型具有不受约束的输出空间;在每个解码步骤中,它们可以产生数万个token中的任何一个。当对SQL等受约束的形式语言进行Fine-tune时,这些模型通常会生成无效代码,使其不可用。本文提出了PICARD模型,一种通过增量解析约束语言模型的自回归解码器的方法。PICARD通过在每个解码步骤拒绝不可接受的token输出来帮助找到有效的输出序列。
方法
-
使用了一种新的用于约束解码的增量解析方法PICARD(Parsing Incrementally for Constrained Auto-Regressive Decoding)来解决这些方法的开销问题。PICARD与任何现有的自回归语言模型解码器和词汇(包括但不限于那些预先训练好的大型Transformer)兼容,而且它不需要非常大的波束大小。PICARD完全不存在于模型的预训练或微调中,它可以在推理时轻松且可选地启用。PICARD直接操作语言模型的输出,即Text-to-SQL任务中生成的SQL语句
-
PICARD将模型预测分数与现有的贪婪算法和波束搜索算法简单地结合在一起,用于语言模型的自回归解码。它的参数是当前翻译输出的token id和对于每个词汇表里的token,模型的语言建模头预测的log-softmax分数。PICARD还可以访问SQL schema的信息,特别是关于表和列的名称以及关于哪个列包含在哪个表中的信息。在每一个生成步骤中,PICARD首先将预测限制在最高k个概率token上,然后给那些没有通过PICARD大量检查的token一个−∞ 的分数(参见图2)
(1). Lexing
在词法分析模式下,PICARD只在词法级别上检查输出。它试图将部分的、detokenize后的模型输出转换为由空格分隔的单个SQL关键字(如select)、标点(如()、操作符(如+和-)、字面值(如SQL条件中的字符串和数字值)以及标识符(如别名、表、和列,而对这些词汇项出现的顺序不敏感。通过这样做,PICARD可以检测关键字中的拼写错误,或者拒绝对给定SQL模式无效的表和列名。例如,考虑以下问题(来自于Spider的验证集中的dog_kernnels数据库):
"What are the email, cell phone and home phone of each professional?"
我们的Fine-tune后的T5-Large模型预测输出为
select email_address, cell_phone, home_phone from professionals
而实际上ground truth是选择 “cell_number” ,而不是无效的 “cell_phone” 列。在词法分析模式下,PICARD可以捕获并避免这一错误。
(2). Parsing without Guards
在Lexing模式之上最低的解析模式(即Parsing without Guards)中,PICARD在语法级别检查输出。PICARD试图将detokenize后的模型输出解析为表示SQL查询的抽象语法树(AST)的数据结构。与词法分析模式相反,关键字和子句出现的顺序现在很重要。PICARD可以拒绝无效的查询结构,例如查找 from 子句中的缺失或子句和关键字的错误顺序。它还可以检测SQL表达式组合的一系列问题:首先,如果PICARD匹配到了一个 tid.cid (注:tid即table_id,cid即column_id)的模式,但是实际上 tid 这个table里并没有 cid 这个column,那么该条输出将被拒绝。其次,如果PICARD首先配到了一个 alias.cid (alias即别名,即我们经常在SQL语句中使用的 FROM table_name as T1 …… )的模式,然后匹配到了一个 tid as alias 的模式,但实际上 tid 这个table里并没有 cid 这个column,那么该条输出也将被拒绝。对于绑定到表别名的子查询,也存在一个等价的规则。最后,PICARD禁止在相同选择范围内重复绑定表别名,但允许隐藏在周围范围内定义的别名。这可能发生在嵌套SQL查询中。
(3). Parsing with Guards
在最高的解析模式中,PICARD在组装SQL的AST时进行额外的分析,称为Guards。如果PICARD匹配到 tid.cid 或者 alias.cid ,然后Guards需要 tid 或 alias 必须在from子句中出现。而且,alias被限制为解析到其中包含cid列的表或子查询。如果PICARD匹配cid模式,那么另一个Guards要求最终将恰好包含具有该id的列的一个表引入作用域。这些Guards被急切地执行,以便在尽可能早的时间将无效的输出快速失效,逐出beam。Guards在解析from子句之后才开始执行。