Humans do not understand the world through text alone. Vision, hearing, touch, and other senses work together to build a complete picture of our surroundings. Multimodal large models give machines, for the first time, the ability to understand and generate content across modalities in a similarly integrated way.
## 1. CLIP: A Breakthrough in Contrastive Learning

### Core Idea

CLIP (Contrastive Language-Image Pre-training) uses contrastive learning to align images and text in a single shared semantic space.

Contrastive loss (InfoNCE):
\[\mathcal{L}_{I \to T} = -\frac{1}{N}\sum_{i=1}^N \log \frac{\exp(\text{sim}(I_i, T_i) / \tau)}{\sum_{j=1}^N \exp(\text{sim}(I_i, T_j) / \tau)}\]

The full objective is symmetric: the analogous text-to-image term is averaged with the image-to-text term above, exactly as `clip_loss` does below.

### PyTorch Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CLIP(nn.Module):
    def __init__(self, image_encoder, text_encoder, embed_dim=512):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        # Learnable temperature; see Section 4 for the more stable
        # log-scale parametrization
        self.temperature = nn.Parameter(torch.tensor(0.07))
        self.image_projection = nn.Linear(image_encoder.output_dim, embed_dim)
        self.text_projection = nn.Linear(text_encoder.output_dim, embed_dim)

    def forward(self, images, texts):
        # Encode each modality
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(texts)
        # Project into the shared space and L2-normalize
        image_embeds = F.normalize(self.image_projection(image_features), dim=-1)
        text_embeds = F.normalize(self.text_projection(text_features), dim=-1)
        # Cosine-similarity matrix, sharpened by the temperature
        logits = image_embeds @ text_embeds.T / self.temperature
        return logits, logits.T

def clip_loss(logits):
    # Matched pairs sit on the diagonal of the similarity matrix
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_i = F.cross_entropy(logits, labels)    # image -> text
    loss_t = F.cross_entropy(logits.T, labels)  # text -> image
    return (loss_i + loss_t) / 2
```
### Zero-Shot Classification
```python
def zero_shot_classify(model, image, class_names):
    # Turn class names into natural-language prompts
    texts = [f"a photo of a {c}" for c in class_names]
    # encode_image / encode_text are assumed helpers on the model
    image_embed = model.encode_image(image)
    text_embeds = model.encode_text(texts)
    image_embed = F.normalize(image_embed, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    # Similarity against every class prompt becomes a distribution
    logits = (image_embed @ text_embeds.T) / model.temperature
    probs = F.softmax(logits, dim=-1)
    return probs[0]  # assumes a single input image
```
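Zero-shot accuracy is sensitive to the prompt template. The CLIP paper reports gains from ensembling many templates per class by averaging their normalized text embeddings; a minimal sketch of that idea, with a small hypothetical template list and the same assumed `encode_text` helper:

```python
TEMPLATES = [
    "a photo of a {}.",
    "a blurry photo of a {}.",
    "a close-up photo of a {}.",
]  # illustrative subset; the paper uses ~80 templates

def ensemble_class_embeds(model, class_names):
    class_embeds = []
    for c in class_names:
        texts = [t.format(c) for t in TEMPLATES]
        embeds = F.normalize(model.encode_text(texts), dim=-1)
        # Average per-template embeddings, then re-normalize
        class_embeds.append(F.normalize(embeds.mean(dim=0), dim=-1))
    return torch.stack(class_embeds)  # [num_classes, embed_dim]
```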
## 2. BLIP: A Generative Vision-Language Model

### Three Architectural Modes

- Image-Text Contrastive (ITC): essentially the CLIP objective from Section 1
- Image-grounded Text Generation (ITG): image caption generation (implemented below)
- Image-Text Matching (ITM): binary matched/unmatched classification (see the head sketch after this list)
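ITC and ITG are covered elsewhere in this post; for ITM, here is a minimal sketch of the matching head, assuming a fused multimodal feature (e.g. the [CLS] output of an image-grounded text encoder) as input:

```python
class ITMHead(nn.Module):
    """Binary classifier: does this (image, text) pair match?"""
    def __init__(self, d_model=768):
        super().__init__()
        self.classifier = nn.Linear(d_model, 2)  # 0 = mismatch, 1 = match

    def forward(self, fused_cls_feature):
        # fused_cls_feature: [batch, d_model]
        return self.classifier(fused_cls_feature)
```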
### Image Caption Generation
```python
class BLIP_ITG(nn.Module):
    def __init__(self, image_encoder, text_decoder):
        super().__init__()
        # text_decoder is assumed to expose embed(), lm_head, a tokenizer,
        # and bos/eos token ids
        self.image_encoder = image_encoder
        self.text_decoder = text_decoder
        # batch_first=True so inputs are [batch, seq, dim]
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=768, num_heads=12, batch_first=True)

    @torch.no_grad()
    def generate_caption(self, image, max_length=50):
        image_embeds = self.image_encoder(image.unsqueeze(0))
        generated = [self.text_decoder.bos_token_id]
        for _ in range(max_length):
            text_embeds = self.text_decoder.embed(torch.tensor([generated]))
            # Cross-modal attention: text queries attend to image features
            text_embeds, _ = self.cross_attention(
                text_embeds, image_embeds, image_embeds)
            logits = self.text_decoder.lm_head(text_embeds)
            # Greedy decoding: take the most likely next token
            next_token = logits[0, -1].argmax().item()
            if next_token == self.text_decoder.eos_token_id:
                break
            generated.append(next_token)
        return self.text_decoder.tokenizer.decode(generated)
```
## 3. GPT-4V: A Unified Multimodal LLM

### Visual Tokenization

**Method 1: Linear Projection (LLaVA)**
```python
class VisualTokenizer(nn.Module):
    def __init__(self, vision_encoder, d_model):
        super().__init__()
        self.encoder = vision_encoder
        self.projection = nn.Linear(vision_encoder.output_dim, d_model)

    def forward(self, images):
        features = self.encoder(images)     # [batch, 196, 1024]
        tokens = self.projection(features)  # [batch, 196, 4096]
        return tokens
```
**Method 2: Q-Former (BLIP-2)**
```python
class QFormer(nn.Module):
    def __init__(self, num_queries=32, d_model=768):
        super().__init__()
        # A fixed set of learnable query vectors
        self.queries = nn.Parameter(torch.randn(num_queries, d_model))
        # Illustrative stand-in for the real Q-Former stack: queries
        # cross-attend to image features through a Transformer decoder
        self.cross_attention = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead=12, batch_first=True),
            num_layers=2)

    def forward(self, image_features):
        batch_size = image_features.size(0)
        queries = self.queries.unsqueeze(0).repeat(batch_size, 1, 1)
        # tgt = queries, memory = image features
        output = self.cross_attention(queries, image_features)
        return output  # [batch, num_queries, d_model]
```
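The trade-off between the two methods: linear projection keeps every patch token, costing hundreds of LLM input positions per image, while the Q-Former compresses the image into a fixed, small number of query outputs (32 above), shortening the LLM's sequence and decoupling its cost from image resolution.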
### Multimodal Prompts

```python
def build_multimodal_prompt(image, text_prompt):
    # vision_tokenizer, text_tokenizer, and llm are assumed to exist in scope
    visual_tokens = vision_tokenizer(image)   # [1, n_img, d_model]
    text_ids = text_tokenizer(text_prompt)    # [1, n_txt] token ids
    text_tokens = llm.embed_tokens(text_ids)  # [1, n_txt, d_model]
    # Prefix style: visual tokens are prepended to the text sequence
    input_embeds = torch.cat([visual_tokens, text_tokens], dim=1)
    return input_embeds

# Usage (multimodal_llm is a hypothetical wrapper around the pieces above)
prompt = "Describe this image in detail:"
response = multimodal_llm.generate(image=img, prompt=prompt)
```
## 4. Training Tricks

### Hard Negative Mining
```python
def hard_negative_mining(image_embeds, text_embeds, num_hard=5):
    sim_matrix = image_embeds @ text_embeds.T
    batch_size = sim_matrix.size(0)
    hard_negatives = []
    for i in range(batch_size):
        # Exclude the positive pair by setting the diagonal entry to -inf,
        # so topk indices stay aligned with the original batch positions
        # (boolean masking would shift every index after position i)
        sims = sim_matrix[i].clone()
        sims[i] = float('-inf')
        hard_neg_idx = sims.topk(num_hard).indices
        hard_negatives.append(hard_neg_idx)
    return hard_negatives
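```

In practice (as in ALBEF/BLIP-style training), these indices select the most confusable texts for each image, and those pairs are fed to the ITM head as hard negatives, which sharpens the matching boundary far more effectively than random in-batch negatives.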
### Temperature Scaling
```python
class LearnableTemperature(nn.Module):
    def __init__(self, init_value=0.07):
        super().__init__()
        # Parametrize in log space: exp() keeps the scale positive,
        # matching the logit_scale trick used by CLIP
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / init_value)))

    def forward(self, similarity):
        # This is the inverse temperature 1/tau, clamped for stability
        scale = torch.clamp(self.logit_scale.exp(), 0.01, 100)
        return similarity * scale
```
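A quick usage sketch, swapping the raw `temperature` division from Section 1 for this log-scale version (reusing `clip_loss` and the normalized embeddings from there):

```python
temp_head = LearnableTemperature(init_value=0.07)
similarity = image_embeds @ text_embeds.T  # cosine similarities in [-1, 1]
logits = temp_head(similarity)             # scaled by clamped exp(logit_scale)
loss = clip_loss(logits)
```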
## 5. Evaluation and Benchmarks

### Major Datasets

| Dataset | Task | Scale |
|---|---|---|
| MS-COCO | Captioning / VQA | 120K images |
| Flickr30K | Image-text retrieval | 30K images |
| VQAv2 | Visual question answering | 1.1M questions |
| LAION-5B | Pre-training | 5B image-text pairs |

### Evaluation Metrics
```python
def calculate_recall(image_embeds, text_embeds, ks=(1, 5, 10)):
    # Assumes row i of the images matches row i of the texts
    sim_matrix = image_embeds @ text_embeds.T
    targets = torch.arange(len(sim_matrix), device=sim_matrix.device)
    recalls = {}
    for K in ks:
        topk_indices = sim_matrix.topk(K, dim=1).indices
        # Recall@K: is the ground-truth text among the top-K retrieved?
        correct = (topk_indices == targets.unsqueeze(1)).any(dim=1)
        recalls[f'I2T_R@{K}'] = correct.float().mean().item()
    return recalls
```
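The text-to-image direction is symmetric: calling the same function with the arguments swapped transposes the similarity matrix, so only the key labels need relabeling:

```python
i2t_recalls = calculate_recall(image_embeds, text_embeds)  # image -> text
t2i_recalls = calculate_recall(text_embeds, image_embeds)  # text -> image
# (keys still read 'I2T_R@K'; rename them to 'T2I_R@K' when reporting)
```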
## 6. Hands-On: Building a Minimal CLIP
```python
from torchvision.models import resnet50, ResNet50_Weights
from transformers import AutoModel

class SimpleCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        # Drop the classification head, keep the pooled 2048-d features
        self.vision_encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.vision_projection = nn.Linear(2048, 512)
        self.text_encoder = AutoModel.from_pretrained('distilbert-base-uncased')
        self.text_projection = nn.Linear(768, 512)
        self.temperature = nn.Parameter(torch.tensor(0.07))

    def encode_image(self, images):
        # flatten(1) keeps the batch dim (squeeze() would drop it for batch=1)
        features = self.vision_encoder(images).flatten(1)
        return F.normalize(self.vision_projection(features), dim=-1)

    def encode_text(self, input_ids, attention_mask):
        outputs = self.text_encoder(input_ids=input_ids,
                                    attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0]  # [CLS] token
        return F.normalize(self.text_projection(pooled), dim=-1)

# Training loop; dataloader is assumed to yield (images, tokenized texts)
model = SimpleCLIP()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for images, texts in dataloader:
    image_embeds = model.encode_image(images)
    text_embeds = model.encode_text(texts['input_ids'], texts['attention_mask'])
    logits = (image_embeds @ text_embeds.T) / model.temperature
    labels = torch.arange(len(images), device=logits.device)
    # Symmetric InfoNCE, as in Section 1
    loss = (F.cross_entropy(logits, labels) +
            F.cross_entropy(logits.T, labels)) / 2
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
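Once trained, the same model supports zero-shot use. A usage sketch, assuming a matching DistilBERT tokenizer and a preprocessed image tensor `img` of shape [3, 224, 224]:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
model.eval()
with torch.no_grad():
    batch = tokenizer(["a photo of a cat", "a photo of a dog"],
                      padding=True, return_tensors='pt')
    text_embeds = model.encode_text(batch['input_ids'], batch['attention_mask'])
    image_embed = model.encode_image(img.unsqueeze(0))
    probs = F.softmax(image_embed @ text_embeds.T / model.temperature, dim=-1)
```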
## 7. Summary

The core of multimodal AI:

- Alignment is key: CLIP aligns images and text with contrastive learning
- Architecture choice: encoder-only vs. encoder-decoder vs. unified LLM
- Data scale: large-scale pre-training is the foundation

Frontier directions:

- ImageBind: a single embedding space spanning six modalities
- Any-to-any generation: arbitrary modalities in, arbitrary modalities out
- World models: physical reasoning and 3D understanding

Multimodal large models are shaping up to be AI's next summit! 🚀