上一篇介绍以图搜图的实现:juejin.cn/post/725585…
,我们利用了卷积神经网络提取特征,然后对比特征相似度,并使用向量数据库加快查找。本文我们将介绍根据文本搜索图片的实现。
首先需要知道根据文本搜索图片具体是什么问题,这里可以有两个层面。第一个则是图片中包含的文本内容,这个可以用OCR识别提取出来。第二个则是深层次的对图片描述的文本,比如红色的狗、跑步的猪、骑猪的人。这些都是对图片内容的描述,相比之下第二种要复杂得多。
OCR是指光学字符识别,也就是我们常说的文字识别。OCR的实现方式是多样的,这里使用Tesseract或者各种神经网络。OCR不是文本重点,因此这里只简单介绍其使用。详情可见:juejin.cn/post/696437…
OCR+文字搜图的原理非常简单,就是先识别文字,然后根据文字模糊查询找到相关图片即可。为了方便查询,这里需要使用数据库。
使用pytesseract模块可以很方便实现OCR,具体代码如下:
import os, cv2
import numpy as np
import pytesseract
from tqdm import tqdm
from PIL import Image
from sqlalchemy import create_engine, String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session
base_path = "G:\datasets\emoji"
files = [os.path.join(base_path, file) for file in os.listdir(base_path) if file.endswith(".jpg")]
for file in files:
try:
image = Image.open(file)
string = pytesseract.image_to_string(image, lang='chi_sim')
print(file, ":", string.strip())
except Exception as e:
pass
其中string就是识别到的文本内容。pytesseract中也提供了批量识别的接口,因为这里存在一些错误图片,因此这里不适用批量接口。
为了方便查询,可以把图片路径和图片中包含的文本内容存储到数据库中。这里使用sqlalchemy+sqlite,代码如下:
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String
class Base(DeclarativeBase):
pass
class ImageInformation(Base):
__tablename__ = "image_information"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
filepath: Mapped[str] = mapped_column(String(255))
content: Mapped[str] = mapped_column(String(255))
def __repr__(self) -> str:
return f"User(id={self.id!r}, filepath={self.filepath!r}, content={self.content!r})"
engine = create_engine("sqlite:///image_search.db", echo=False)
Base.metadata.create_all(engine)
其中ImageInformation类对应我们需要创建的数据库表。创建好后,识别图片的文字,然后存储到图片数据库中:
base_path = "G:\datasets\emoji"
files = [os.path.join(base_path, file) for file in os.listdir(base_path) if file.endswith(".jpg")]
bar = tqdm(total=len(files))
for file in files:
try:
# 识别文字
image = Image.open(file)
string = pytesseract.image_to_string(image, lang='chi_sim').strip()
file = file[:255] if len(file) > 255 else file
string = string[:255] if len(string) > 255 else string
# 存储数据库
with Session(engine) as session:
info = ImageInformation(filepath=file, content=string)
session.add_all([info])
session.commit()
except Exception as e:
pass
bar.update(1)
这个过程会比较久。
完成上面的存储操作后,就可以开始根据文字查找图片了。这里只需要使用简单的数据库查询操作即可完成,代码如下,我们先把你好作为输入文本:
keyword = '你好'
w, h = 224, 224
with Session(engine) as session:
stmt = select(ImageInformation).where(ImageInformation.content.contains(keyword)).limit(8)
images = [cv2.resize(cv2.imread(ii.filepath), (w, h)) for ii in session.scalars(stmt)]
if len(images) > 0:
result = np.hstack(images)
cv2.imwrite("result.jpg", result)
else:
print("没有找到结果")
下面是查询到的结果图片:
如果关键词改为喜欢,得到结果如下:
经过测试,发现在一些短文本搜索中,这种方式比较奏效,但是在长文本则经常搜索不到结果。一种改进方式是不存储文本本身,而是使用Bert等模型把文本转换成Embedding,然后存储Embedding。这样我们就不能再使用sqlite了,而需要使用向量数据库。
在前面的例子中,搜索结果非常依赖字符串匹配。比如查找鸡,只有图片中有鸡字才会被搜索到,而与坤相关的图片则查找不到。为此我们用Transformer对上面进行改进,主要思路就是先识别文字,然后把文字交给文本编码器,转换成Embedding,然后在查找时查找输入文本和Embedding的相似度,这样就可以缓解上述问题。
这里我们还是使用向量数据库,向量数据库有很多选择,这里使用Milvus数据库,具体使用可以参考:milvus.io/docs/instal…
首先创建数据库和集合:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
# 创建数据库
connections.connect(host='127.0.0.1', port='19530')
def create_milvus_collection(collection_name, dim):
if utility.has_collection(collection_name):
utility.drop_collection(collection_name)
fields = [
FieldSchema(name='id', dtype=DataType.INT64, descrition='ids', max_length=500, is_primary=True,
auto_id=True),
FieldSchema(name='filepath', dtype=DataType.VARCHAR, description='filepath', max_length=512),
FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, descrition='embedding vectors', dim=dim),
]
schema = CollectionSchema(fields=fields, description='reverse image search')
collection = Collection(name=collection_name, schema=schema)
# create IVF_FLAT index for collection.
index_params = {
'metric_type': 'L2',
'index_type': "IVF_FLAT",
'params': {"nlist": 2048}
}
collection.create_index(field_name="embedding", index_params=index_params)
return collection
collection = create_milvus_collection('image_information', 768)
集合中主要有filepath和embedding两个字段。其中embedding有768个维度,这是由Transformer决定的,我这里选择的Transformer输出768个维度,因此这里填768。
创建完成后,就是读取图片、识别文字、文字编码、存入数据库。其中文字编码可以使用Transformers模块或者text2vec完成,这里使用text2vec,其操作如下:
from text2vec import SentenceModel
model = SentenceModel('shibing624/text2vec-base-chinese')
embeddings = model.encode(['不要温顺地走进那个良夜'])
print(embeddings.shape)
在创建SentenceModel时传入对应的模型,然后调用model.encode方法即可。输出如下结果:
(1, 768)
其余操作则不详细解释,具体代码如下:
from text2vec import SentenceModel
model = SentenceModel('shibing624/text2vec-base-chinese', device="cuda")
base_path = "G:\datasets\emoji"
files = [os.path.join(base_path, file) for file in os.listdir(base_path) if file.endswith(".jpg")]
bar = tqdm(total=len(files))
for idx, file in enumerate(files):
try:
image = Image.open(file)
string = pytesseract.image_to_string(image, lang='chi_sim').strip()
embedding = model.encode([string])[0]
collection.insert([
[file],
[embedding]
])
except Exception as e:
pass
bar.update(1)
在插入数据后,直接使用数据库的查询操作即可完成搜索操作,具体代码如下:
import cv2
import numpy as np
from text2vec import SentenceModel
from pymilvus import connections, Collection
# 加载模型
model = SentenceModel('shibing624/text2vec-base-chinese', device="cuda")
# 连接数据库,加载集合
connections.connect(host='127.0.0.1', port='19530')
collection = Collection(name='image_information')
search_params = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": 5}
collection.load()
# 用来查询的文本
keyword = "今天不开心"
embedding = model.encode([keyword])
print(embedding.shape)
# 在数据库中搜索
results = collection.search(
data=[embedding[0]],
anns_field='embedding',
param=search_params,
output_fields=['filepath'],
limit=10,
consistency_level="Strong"
)
collection.release()
# 展示查询结果
w, h = 224, 224
images = []
for result in results[0]:
entity = result.entity
filepath = entity.get('filepath')
image = cv2.resize(cv2.imread(filepath), (w, h))
images.append(np.array(image))
result = np.hstack(images)
cv2.imwrite("result.jpg", result)
向量数据库在查询时,可以根据向量的相似度返回查询结果。在前面我们存储了句向量,所以我们可以把查询文本转换成句向量,然后利用向量数据库的查询功能,查找相似结果。在上面代码中,我们查询“今天不开心”,这次不再是字符串层面的查询,而是句子含义层面的查询,因此可以查询的不包含这些字符的图片,下面是查询结果:
把关键词修改为“我想吃饭”后得到下面的结果:
整体效果还是非常不错的。
不过前面的结果是建立在能在图片中识别到文本的情况下,如果是我们随手拍的照片,那么就不能使用上面的方式来实现文字搜索图片。
在多模态领域有许多组合模型,而我们需要的就是Image-to-Text类模型。如果要手工给图片添加画面描述会非常麻烦,因此我们选择使用Image-to-Text模型完成自动识别。
基于图片含义的文字搜图的实现与前面基于OCR的类似,只不过需要把OCR修改为Image Captioning网络。在前面我们的流程是:
现在只需要把第二步修改为使用Image Captioning生成图片描述即可。后面部分则是完全一致的。
像这类输入图片,输出画面描述的任务叫做Image Captioning,用于这一任务的模型非常多。包括CNN+LSTM,Vit等都可以实现Image Captioning。两者都是一个Encoder-Decoder架构,使用CNN、Vit作为图片Encoder,将图片转换成特征图或者特征向量。然后把Encoder的输出作为Decoder的输入,并输入,然后依次生成图片描述。
以Vit为例,其结构如图:
Vit其实就是一个为图片设计的Transformer架构,在某些细节上为图片做了一些修改。
首先我们可以使用和前面相同的方式创建数据库,这里不再重复,我们复用前面的数据库image_information。然后需要修改插入数据的代码,首先来创建一个函数加载模型,并创建一个函数用于将图片转换成文本向量,代码如下:
import os
import torch
from tqdm import tqdm
from PIL import Image
from text2vec import SentenceModel
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from pymilvus import Collection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model():
"""
加载需要使用到的模型
"""
sentence_model = SentenceModel('shibing624/text2vec-base-chinese', device="cuda")
model = VisionEncoderDecoderModel.from_pretrained("bipin/image-caption-generator")
image_processor = ViTImageProcessor.from_pretrained("bipin/image-caption-generator")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model.to(device)
return sentence_model, model, image_processor, tokenizer
def get_embedding(filepath):
"""
输入图片路径,将图片转成描述向量
"""
pixel_values = image_processor(images=[Image.open(filepath)], return_tensors="pt").pixel_values.to(device)
output_ids = model.generate(pixel_values, num_beams=4, max_length=128)
pred = tokenizer.decode(output_ids[0], skip_special_tokens=True)
return sentence_model.encode(pred)
后续只需要调用get_embedding函数就可以完成图片到向量的转换。接下来就是修改插入数据的代码,具体如下:
connections.connect(host='127.0.0.1', port='19530')
collection = Collection("image_information")
collection.load()
sentence_model, model, image_processor, tokenizer = load_model()
base_path = "G:\datasets\people"
files = [os.path.join(base_path, file) for file in os.listdir(base_path)]
bar = tqdm(total=len(files))
for idx, file in enumerate(files):
try:
embedding = get_embedding(file)
collection.insert([
[file],
[embedding]
])
except Exception as e:
pass
bar.update(1)
最后则是搜图操作了,这个和前面是完全一样的:
search_params = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": 5}
# 用来查询的文本
keyword = "girl"
embedding = sentence_model.encode([keyword])
# 在数据库中搜索
results = collection.search(
data=[embedding[0]],
anns_field='embedding',
param=search_params,
output_fields=['filepath'],
limit=10,
consistency_level="Strong"
)
collection.release()
# 展示查询结果
w, h = 224, 224
images = []
for result in results[0]:
entity = result.entity
filepath = entity.get('filepath')
image = cv2.resize(cv2.imread(filepath), (w, h))
images.append(np.array(image))
result = np.hstack(images)
cv2.imwrite("result.jpg", result)
因为这里选择的Image Captioning模型输出为英文,因此这里把英文作为关键字。这里关键字为"girl",下面是搜索结果:
因为数据库中还存储了之前的表情包,因此表情包中关于与"girl"有关的表情包也搜索出来了,比如"娘们"、"女人"等。
如果把关键字改为"smile girl",搜索结果如下:
如果图片数量足够,则可以得到一个比较好的搜索结果。
上面的结果还可以有一些改进,在Image Captioning步骤,我们只生成了一个描述。在很多情况下,这个描述不一定准确,比如下面的图片:
可以描述为“拿着话筒的姑娘”、“一个姑娘在微笑”或者“一个拿着话筒的姑娘在微笑”。因此我们可以生成多个描述,存入数据库,这样在查找时结果可以更准确。可以通过修改temperature参数来生成不同的描述:
output_ids = model.generate(pixel_values, num_beams=4, max_length=128, temperature=0.8)
当temperature小于1时,生成结果带有随机性。temperature越小,结果越随机。
阅读量:1263
点赞量:0
收藏量:0