[TOC]

  • 名称:Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition
  • 论文:https://arxiv.org/abs/2103.06495
  • 会议:AAAI2020
  • Github: https://github.com/FangShancheng/ABINet

ABINet(Attention-based Bidirectional Network)是一种用于场景文本识别(Scene Text Recognition, STR)的深度学习模型。它在处理复杂背景、噪声干扰以及弯曲或倾斜文本时表现出色。ABINet 的核心创新点是引入了 双向注意力机制迭代优化策略 ,从而显著提升了文本识别的准确性和鲁棒性。

以下是 ABINet 的详细解析,包括其架构设计、工作原理、优势和实现细节。


1. ABINet 的背景

问题

  • 自然场景中的文本通常具有复杂的形状(如弯曲、倾斜等),并且背景可能包含大量噪声。
  • 传统的基于分类的方法(如 CRNN)难以处理长文本和复杂背景。
  • 现有的注意力机制通常只关注单向信息流(从左到右或从右到左),忽略了上下文的双向依赖关系。

解决方案

  • ABINet 提出了一个端到端的框架,通过以下创新点解决上述问题:
    1. 双向注意力机制 :同时建模从左到右和从右到左的上下文信息,本质就是Bert
    2. 迭代优化策略 :通过多次迭代逐步优化预测结果。
    3. 视觉和语言联合建模 :结合视觉特征和语言先验知识,提升识别性能。

2. ABINet 的架构

image-20250422170653342

ABINet 的整体架构包括以下几个关键部分:

2.1 主干网络(Backbone)

  • 主干网络用于提取输入图像的视觉特征。
  • 常用的主干网络包括 ResNet、MobileNet 等。
  • 输出是一个特征图,表示每个位置的视觉特征。

2.2 视觉注意力模块(Visual Attention Module)

主干网络提取特征图,位置编码初始化Q,K是特征图经过Unet生成,V是特征图本身。学习每个位置编码关注的文本区域。

image-20250422171519011
  • 视觉注意力模块基于编码器-解码器结构,利用注意力机制生成字符序列。
  • 包括两个方向的注意力:
    • 正向注意力 :从左到右生成字符序列。
    • 反向注意力 :从右到左生成字符序列。

2.3 语言模型模块(Language Model Module)

  • Q由位置编码初始化,K和V由视觉概率图提供
  • 完形填空式预测:只看上下文,自己不进行attention,所有增加了Mask
  • 视觉模型只负责提取特征,后续矫正只更新文本模型
image-20250422171618159
  • 语言模型模块基于双向 Transformer,利用语言先验知识对字符序列进行校正。
  • 它能够捕捉字符之间的依赖关系,从而提高识别的准确性。

2.4 迭代优化策略

  • ABINet 采用多次迭代的方式逐步优化预测结果:
    1. 第一次迭代:基于视觉特征生成初步的字符序列。
    2. 后续迭代:结合语言模型对初步结果进行校正。
    3. 最终输出:后续经过多次迭代后的最优预测结果。
    4. 迭代只有第一次和第二次用到视觉特征,之后的输入都只来自融合模型

3. ABINet 的工作原理

3.1 双向注意力机制

  • 正向注意力和反向注意力分别建模从左到右和从右到左的上下文信息。
  • 通过融合两个方向的注意力结果,获得更全面的上下文信息。

3.2 语言模型的作用

  • 语言模型模块基于 Transformer 架构,利用双向注意力机制捕捉字符之间的依赖关系。
  • 它能够纠正初步预测中的错误(如拼写错误或顺序错误)。

3.3 迭代优化

  • 每次迭代都会将前一次的预测结果作为输入,结合语言模型进行进一步优化。
  • 这种迭代策略能够逐步逼近最优解,尤其适用于长文本和复杂场景。

4. ABINet 的优势

  1. 双向上下文建模
    • 通过双向注意力机制,充分利用了上下文信息,提高了识别的准确性。
  2. 语言先验知识
    • 结合语言模型模块,能够捕捉字符之间的依赖关系,从而纠正初步预测中的错误。
  3. 迭代优化
    • 多次迭代逐步优化预测结果,特别适合处理长文本和复杂场景。
  4. 鲁棒性
    • 能够处理弯曲、倾斜、遮挡等复杂场景中的文本。

5. 代码实现

以下是一个基于 PyTorch 的简化实现示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torch.nn as nn
import torch.nn.functional as F

class ABINet(nn.Module):
def __init__(self, backbone, num_classes, max_iter=3):
super(ABINet, self).__init__()
self.backbone = backbone
self.visual_attention = VisualAttentionModule(num_classes)
self.language_model = LanguageModelModule(num_classes)
self.max_iter = max_iter

def forward(self, x):
# 主干网络提取特征
features = self.backbone(x)

# 初始化预测结果
predictions = None

# 迭代优化
for _ in range(self.max_iter):
# 视觉注意力模块生成初步预测
if predictions is None:
visual_output = self.visual_attention(features)
else:
visual_output = self.visual_attention(features, predictions)

# 语言模型模块校正预测结果
predictions = self.language_model(visual_output)

return predictions

class VisualAttentionModule(nn.Module):
def __init__(self, num_classes):
super(VisualAttentionModule, self).__init__()
self.attention = nn.MultiheadAttention(embed_dim=512, num_heads=8)
self.fc = nn.Linear(512, num_classes)

def forward(self, features, prev_predictions=None):
# 注意力机制生成字符序列
attention_output, _ = self.attention(features, features, features)
logits = self.fc(attention_output)
return logits

class LanguageModelModule(nn.Module):
def __init__(self, num_classes):
super(LanguageModelModule, self).__init__()
self.transformer = nn.Transformer(d_model=512, num_encoder_layers=2, num_decoder_layers=2)
self.fc = nn.Linear(512, num_classes)

def forward(self, inputs):
# 双向 Transformer 校正字符序列
transformer_output = self.transformer(inputs, inputs)
logits = self.fc(transformer_output)
return logits

# 示例用法
if __name__ == "__main__":
# 定义一个简单的主干网络
class SimpleBackbone(nn.Module):
def __init__(self):
super(SimpleBackbone, self).__init__()
self.conv = nn.Conv2d(3, 512, kernel_size=3, padding=1)

def forward(self, x):
return self.conv(x)

# 初始化模型
backbone = SimpleBackbone()
model = ABINet(backbone, num_classes=37) # 假设有 36 个字符 + 1 个空白符

# 输入图像
image = torch.randn(1, 3, 32, 128) # [B, C, H, W]

# 前向传播
predictions = model(image)

print("预测结果形状:", predictions.shape)

6. 运行结果

假设输入图像的尺寸为 [1, 3, 32, 128],则输出结果如下:

1
预测结果形状: torch.Size([1, L, 37])

其中:

  • L 是字符序列的最大长度。
  • 37 是词汇表的大小(包括空白符)。

7. 注意事项

  1. 主干网络的选择
    • 主干网络的选择会影响特征提取的质量,建议使用预训练的深度网络(如 ResNet、MobileNet)。
  2. 迭代次数的设置
    • 迭代次数需要根据具体任务进行调整,过多的迭代可能导致过拟合。
  3. 语言模型的设计
    • 语言模型的架构需要结合具体任务,确保能够捕捉字符之间的依赖关系。
  4. 后处理
    • 在推理阶段,需要对预测结果进行解码(如使用贪婪搜索或束搜索)生成最终的文本。

8. 总结

ABINet 是一种高效的场景文本识别模型,通过引入双向注意力机制和迭代优化策略,解决了传统方法中上下文信息不足和长文本识别困难的问题。它的设计简单且高效,能够处理复杂的自然场景文本识别任务。通过结合视觉特征和语言先验知识,ABINet 在精度和鲁棒性之间取得了良好的平衡,成为场景文本识别领域的经典方法之一。