如何实现文字搜图-灵析社区

小乔学算法

一、前言

上一篇介绍以图搜图的实现:juejin.cn/post/725585…
,我们利用了卷积神经网络提取特征,然后对比特征相似度,并使用向量数据库加快查找。本文我们将介绍根据文本搜索图片的实现。

首先需要知道根据文本搜索图片具体是什么问题,这里可以有两个层面。第一个则是图片中包含的文本内容,这个可以用OCR识别提取出来。第二个则是深层次的对图片描述的文本,比如红色的狗、跑步的猪、骑猪的人。这些都是对图片内容的描述,相比之下第二种要复杂得多。

二、OCR+文字搜图

OCR是指光学字符识别,也就是我们常说的文字识别。OCR的实现方式是多样的,这里使用Tesseract或者各种神经网络。OCR不是文本重点,因此这里只简单介绍其使用。详情可见:juejin.cn/post/696437…

OCR+文字搜图的原理非常简单,就是先识别文字,然后根据文字模糊查询找到相关图片即可。为了方便查询,这里需要使用数据库。

2.1 文字识别

使用pytesseract模块可以很方便实现OCR,具体代码如下:

import os, cv2

import numpy as np
import pytesseract
from tqdm import tqdm
from PIL import Image
from sqlalchemy import create_engine, String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session


base_path = "G:\datasets\emoji"
files = [os.path.join(base_path, file) for file in os.listdir(base_path) if file.endswith(".jpg")]
for file in files:
try:
    image = Image.open(file)
    string = pytesseract.image_to_string(image, lang='chi_sim')
    print(file, ":", string.strip())
except Exception as e:
    pass

其中string就是识别到的文本内容。pytesseract中也提供了批量识别的接口,因为这里存在一些错误图片,因此这里不适用批量接口。

2.2 存储数据库

为了方便查询,可以把图片路径和图片中包含的文本内容存储到数据库中。这里使用sqlalchemy+sqlite,代码如下:

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String


class Base(DeclarativeBase):
    pass


class ImageInformation(Base):
    __tablename__ = "image_information"
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    filepath: Mapped[str] = mapped_column(String(255))
    content: Mapped[str] = mapped_column(String(255))

    def __repr__(self) -> str:
        return f"User(id={self.id!r}, filepath={self.filepath!r}, content={self.content!r})"


engine = create_engine("sqlite:///image_search.db", echo=False)
Base.metadata.create_all(engine)

其中ImageInformation类对应我们需要创建的数据库表。创建好后,识别图片的文字,然后存储到图片数据库中:

base_path = "G:\datasets\emoji"
files = [os.path.join(base_path, file) for file in os.listdir(base_path) if file.endswith(".jpg")]
bar = tqdm(total=len(files))
for file in files:
    try:
        # 识别文字
        image = Image.open(file)
        string = pytesseract.image_to_string(image, lang='chi_sim').strip()
        
        file = file[:255] if len(file) > 255 else file
        string = string[:255] if len(string) > 255 else string
        
        # 存储数据库
        with Session(engine) as session:
            info = ImageInformation(filepath=file, content=string)
            session.add_all([info])
            session.commit()
    except Exception as e:
        pass
    bar.update(1)

这个过程会比较久。

2.3 根据文字搜索图片

完成上面的存储操作后,就可以开始根据文字查找图片了。这里只需要使用简单的数据库查询操作即可完成,代码如下,我们先把你好作为输入文本:

keyword = '你好'
w, h = 224, 224
with Session(engine) as session:
    stmt = select(ImageInformation).where(ImageInformation.content.contains(keyword)).limit(8)
    images = [cv2.resize(cv2.imread(ii.filepath), (w, h)) for ii in session.scalars(stmt)]
    if len(images) > 0:
        result = np.hstack(images)
        cv2.imwrite("result.jpg", result)
    else:
        print("没有找到结果")

下面是查询到的结果图片:

如果关键词改为喜欢,得到结果如下:


经过测试,发现在一些短文本搜索中,这种方式比较奏效,但是在长文本则经常搜索不到结果。一种改进方式是不存储文本本身,而是使用Bert等模型把文本转换成Embedding,然后存储Embedding。这样我们就不能再使用sqlite了,而需要使用向量数据库。

三、基于Transformer的改进

在前面的例子中,搜索结果非常依赖字符串匹配。比如查找鸡,只有图片中有鸡字才会被搜索到,而与坤相关的图片则查找不到。为此我们用Transformer对上面进行改进,主要思路就是先识别文字,然后把文字交给文本编码器,转换成Embedding,然后在查找时查找输入文本和Embedding的相似度,这样就可以缓解上述问题。

3.1 创建数据库

这里我们还是使用向量数据库,向量数据库有很多选择,这里使用Milvus数据库,具体使用可以参考:milvus.io/docs/instal…

首先创建数据库和集合:

from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

# 创建数据库
connections.connect(host='127.0.0.1', port='19530')


def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, descrition='ids', max_length=500, is_primary=True,
                    auto_id=True),
        FieldSchema(name='filepath', dtype=DataType.VARCHAR, description='filepath', max_length=512),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, descrition='embedding vectors', dim=dim),
    ]
    schema = CollectionSchema(fields=fields, description='reverse image search')
    collection = Collection(name=collection_name, schema=schema)

    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type': 'L2',
        'index_type': "IVF_FLAT",
        'params': {"nlist": 2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection


collection = create_milvus_collection('image_information', 768)

集合中主要有filepath和embedding两个字段。其中embedding有768个维度,这是由Transformer决定的,我这里选择的Transformer输出768个维度,因此这里填768。

3.2 text2vec

创建完成后,就是读取图片、识别文字、文字编码、存入数据库。其中文字编码可以使用Transformers模块或者text2vec完成,这里使用text2vec,其操作如下:

from text2vec import SentenceModel

model = SentenceModel('shibing624/text2vec-base-chinese')
embeddings = model.encode(['不要温顺地走进那个良夜'])
print(embeddings.shape)

在创建SentenceModel时传入对应的模型,然后调用model.encode方法即可。输出如下结果:

(1, 768)

其余操作则不详细解释,具体代码如下:

from text2vec import SentenceModel

model = SentenceModel('shibing624/text2vec-base-chinese', device="cuda")

base_path = "G:\datasets\emoji"
files = [os.path.join(base_path, file) for file in os.listdir(base_path) if file.endswith(".jpg")]
bar = tqdm(total=len(files))
for idx, file in enumerate(files):
    try:
        image = Image.open(file)
        string = pytesseract.image_to_string(image, lang='chi_sim').strip()
        embedding = model.encode([string])[0]
        collection.insert([
            [file],
            [embedding]
        ])
    except Exception as e:
        pass
    bar.update(1)

3.3 根据文字搜索图片

在插入数据后,直接使用数据库的查询操作即可完成搜索操作,具体代码如下:

import cv2
import numpy as np
from text2vec import SentenceModel
from pymilvus import connections, Collection

# 加载模型
model = SentenceModel('shibing624/text2vec-base-chinese', device="cuda")
# 连接数据库,加载集合
connections.connect(host='127.0.0.1', port='19530')
collection = Collection(name='image_information')
search_params = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": 5}
collection.load()
# 用来查询的文本
keyword = "今天不开心"
embedding = model.encode([keyword])
print(embedding.shape)
# 在数据库中搜索
results = collection.search(
    data=[embedding[0]],
    anns_field='embedding',
    param=search_params,
    output_fields=['filepath'],
    limit=10,
    consistency_level="Strong"
)
collection.release()
# 展示查询结果
w, h = 224, 224
images = []
for result in results[0]:
    entity = result.entity
    filepath = entity.get('filepath')
    image = cv2.resize(cv2.imread(filepath), (w, h))
    images.append(np.array(image))
result = np.hstack(images)
cv2.imwrite("result.jpg", result)

向量数据库在查询时,可以根据向量的相似度返回查询结果。在前面我们存储了句向量,所以我们可以把查询文本转换成句向量,然后利用向量数据库的查询功能,查找相似结果。在上面代码中,我们查询“今天不开心”,这次不再是字符串层面的查询,而是句子含义层面的查询,因此可以查询的不包含这些字符的图片,下面是查询结果:

把关键词修改为“我想吃饭”后得到下面的结果:

整体效果还是非常不错的。

不过前面的结果是建立在能在图片中识别到文本的情况下,如果是我们随手拍的照片,那么就不能使用上面的方式来实现文字搜索图片。

四、基于图片含义的文字搜图

在多模态领域有许多组合模型,而我们需要的就是Image-to-Text类模型。如果要手工给图片添加画面描述会非常麻烦,因此我们选择使用Image-to-Text模型完成自动识别。

4.1 实现原理

基于图片含义的文字搜图的实现与前面基于OCR的类似,只不过需要把OCR修改为Image Captioning网络。在前面我们的流程是:

  1. 读取图片
  2. OCR识别
  3. 把识别结果转换成向量
  4. 存入数据库

现在只需要把第二步修改为使用Image Captioning生成图片描述即可。后面部分则是完全一致的。

4.2 Image Captioning

像这类输入图片,输出画面描述的任务叫做Image Captioning,用于这一任务的模型非常多。包括CNN+LSTM,Vit等都可以实现Image Captioning。两者都是一个Encoder-Decoder架构,使用CNN、Vit作为图片Encoder,将图片转换成特征图或者特征向量。然后把Encoder的输出作为Decoder的输入,并输入,然后依次生成图片描述。

以Vit为例,其结构如图:

Vit其实就是一个为图片设计的Transformer架构,在某些细节上为图片做了一些修改。

4.3 根据图片含义搜索图片

首先我们可以使用和前面相同的方式创建数据库,这里不再重复,我们复用前面的数据库image_information。然后需要修改插入数据的代码,首先来创建一个函数加载模型,并创建一个函数用于将图片转换成文本向量,代码如下:

import os
import torch
from tqdm import tqdm
from PIL import Image
from text2vec import SentenceModel
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from pymilvus import Collection

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model():
    """
    加载需要使用到的模型
    """
    sentence_model = SentenceModel('shibing624/text2vec-base-chinese', device="cuda")
    model = VisionEncoderDecoderModel.from_pretrained("bipin/image-caption-generator")
    image_processor = ViTImageProcessor.from_pretrained("bipin/image-caption-generator")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model.to(device)
    return sentence_model, model, image_processor, tokenizer


def get_embedding(filepath):
    """
    输入图片路径,将图片转成描述向量
    """
    pixel_values = image_processor(images=[Image.open(filepath)], return_tensors="pt").pixel_values.to(device)
    output_ids = model.generate(pixel_values, num_beams=4, max_length=128)
    pred = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return sentence_model.encode(pred)

后续只需要调用get_embedding函数就可以完成图片到向量的转换。接下来就是修改插入数据的代码,具体如下:

connections.connect(host='127.0.0.1', port='19530')
collection = Collection("image_information")
collection.load()
sentence_model, model, image_processor, tokenizer = load_model()
base_path = "G:\datasets\people"
files = [os.path.join(base_path, file) for file in os.listdir(base_path)]
bar = tqdm(total=len(files))
for idx, file in enumerate(files):
    try:
        embedding = get_embedding(file)
        collection.insert([
            [file],
            [embedding]
        ])
    except Exception as e:
        pass
    bar.update(1)

最后则是搜图操作了,这个和前面是完全一样的:

search_params = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": 5}
# 用来查询的文本
keyword = "girl"
embedding = sentence_model.encode([keyword])
# 在数据库中搜索
results = collection.search(
    data=[embedding[0]],
    anns_field='embedding',
    param=search_params,
    output_fields=['filepath'],
    limit=10,
    consistency_level="Strong"
)
collection.release()
# 展示查询结果
w, h = 224, 224
images = []
for result in results[0]:
    entity = result.entity
    filepath = entity.get('filepath')
    image = cv2.resize(cv2.imread(filepath), (w, h))
    images.append(np.array(image))
result = np.hstack(images)
cv2.imwrite("result.jpg", result)

因为这里选择的Image Captioning模型输出为英文,因此这里把英文作为关键字。这里关键字为"girl",下面是搜索结果:

因为数据库中还存储了之前的表情包,因此表情包中关于与"girl"有关的表情包也搜索出来了,比如"娘们"、"女人"等。

如果把关键字改为"smile girl",搜索结果如下:

如果图片数量足够,则可以得到一个比较好的搜索结果。

上面的结果还可以有一些改进,在Image Captioning步骤,我们只生成了一个描述。在很多情况下,这个描述不一定准确,比如下面的图片:


可以描述为“拿着话筒的姑娘”、“一个姑娘在微笑”或者“一个拿着话筒的姑娘在微笑”。因此我们可以生成多个描述,存入数据库,这样在查找时结果可以更准确。可以通过修改temperature参数来生成不同的描述:

output_ids = model.generate(pixel_values, num_beams=4, max_length=128, temperature=0.8)

当temperature小于1时,生成结果带有随机性。temperature越小,结果越随机。

阅读量:1263

点赞量:0

收藏量:0