Multimodal Large Models

Vision-Language Understanding from CLIP to GPT-4V

Posted by Feng Yu on October 20, 2024

Humans do not understand the world through text alone. We combine vision, hearing, touch, and other senses to build a complete picture of our surroundings. Multimodal large models let machines, for the first time, understand and generate content across modalities the way humans do.


1. CLIP: The Contrastive Learning Breakthrough

Core Idea

CLIP (Contrastive Language-Image Pre-training) uses contrastive learning to align images and text in a single shared semantic space.

Contrastive Loss (InfoNCE)

\[\mathcal{L} = -\frac{1}{N}\sum_{i=1}^N \log \frac{\exp(\text{sim}(I_i, T_i) / \tau)}{\sum_{j=1}^N \exp(\text{sim}(I_i, T_j) / \tau)}\]

This is the image-to-text direction; the full CLIP objective symmetrizes it with the analogous text-to-image term, as clip_loss below does.

PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class CLIP(nn.Module):
    def __init__(self, image_encoder, text_encoder, embed_dim=512):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        # Learnable temperature; the official CLIP parameterizes it in
        # log space instead (see LearnableTemperature in Section 4)
        self.temperature = nn.Parameter(torch.tensor(0.07))
        
        self.image_projection = nn.Linear(image_encoder.output_dim, embed_dim)
        self.text_projection = nn.Linear(text_encoder.output_dim, embed_dim)
    
    def forward(self, images, texts):
        # Encode each modality
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(texts)
        
        # Project into the shared embedding space and L2-normalize
        image_embeds = F.normalize(self.image_projection(image_features), dim=-1)
        text_embeds = F.normalize(self.text_projection(text_features), dim=-1)
        
        # Pairwise cosine-similarity matrix, scaled by temperature
        logits = image_embeds @ text_embeds.T / self.temperature
        
        return logits, logits.T

def clip_loss(logits):
    # Ground-truth pairs lie on the diagonal of the similarity matrix
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_i = F.cross_entropy(logits, labels)    # image-to-text
    loss_t = F.cross_entropy(logits.T, labels)  # text-to-image
    return (loss_i + loss_t) / 2

def clip_loss(logits):
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_i = F.cross_entropy(logits, labels)
    loss_t = F.cross_entropy(logits.T, labels)
    return (loss_i + loss_t) / 2

Zero-Shot Classification

def zero_shot_classify(model, image, class_names):
    # Wrap each class name in a prompt template
    texts = [f"a photo of a {c}" for c in class_names]
    
    image_embed = model.encode_image(image)
    text_embeds = model.encode_text(texts)
    
    image_embed = F.normalize(image_embed, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    
    # Similarity of the image to every class prompt, as a distribution
    logits = (image_embed @ text_embeds.T) / model.temperature
    probs = F.softmax(logits, dim=-1)
    
    return probs[0]  # assumes a single input image
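
A hypothetical usage sketch, assuming model is a trained CLIP exposing encode_image/encode_text and image is an already preprocessed tensor:

class_names = ["cat", "dog", "airplane"]
probs = zero_shot_classify(model, image, class_names)
print(class_names[probs.argmax().item()])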

2. BLIP: A Generative Vision-Language Model

Three Training Objectives

  1. Image-Text Contrastive (ITC): CLIP-style contrastive alignment
  2. Image-grounded Text Generation (ITG): image caption generation
  3. Image-Text Matching (ITM): binary matched/unmatched classification (a sketch follows this list)
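
ITC is essentially the CLIP objective above, and ITG is sketched in the next block. A minimal ITM head, assuming a fused multimodal feature (e.g. the [CLS] output of a cross-attention encoder) is available, might look like:

class ITMHead(nn.Module):
    def __init__(self, hidden_dim=768):
        super().__init__()
        # Binary classifier: does this (image, text) pair match?
        self.classifier = nn.Linear(hidden_dim, 2)

    def forward(self, multimodal_cls):
        # multimodal_cls: [batch, hidden_dim] fused image-text feature
        # (an assumed input for illustration; BLIP's actual API differs)
        return self.classifier(multimodal_cls)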

Image Caption Generation

class BLIP_ITG(nn.Module):
    def __init__(self, image_encoder, text_decoder):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_decoder = text_decoder
        # batch_first=True so inputs are [batch, seq, dim]
        self.cross_attention = nn.MultiheadAttention(embed_dim=768, num_heads=12,
                                                     batch_first=True)
    
    @torch.no_grad()
    def generate_caption(self, image, max_length=50):
        image_embeds = self.image_encoder(image.unsqueeze(0))
        
        # Greedy decoding, one token at a time
        generated = [self.text_decoder.bos_token_id]
        for _ in range(max_length):
            text_embeds = self.text_decoder.embed(torch.tensor([generated]))
            
            # Cross-modal attention: text queries attend to image features
            text_embeds, _ = self.cross_attention(text_embeds, image_embeds, image_embeds)
            
            logits = self.text_decoder.lm_head(text_embeds)
            next_token = logits[0, -1].argmax().item()
            
            if next_token == self.text_decoder.eos_token_id:
                break
            generated.append(next_token)
        
        return self.text_decoder.tokenizer.decode(generated)
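
The loop above decodes greedily (argmax at each step); BLIP's released implementation typically uses beam search or nucleus sampling instead, which tends to yield more fluent and diverse captions.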

3. GPT-4V: A Unified Multimodal LLM

Visual Tokenization

Approach 1: Linear Projection (LLaVA)

class VisualTokenizer(nn.Module):
    def __init__(self, vision_encoder, d_model):
        super().__init__()
        self.encoder = vision_encoder
        self.projection = nn.Linear(vision_encoder.output_dim, d_model)
    
    def forward(self, images):
        # Patch features, e.g. from a ViT: [batch, 196, 1024]
        features = self.encoder(images)
        # Projected into the LLM's embedding space: [batch, 196, 4096]
        tokens = self.projection(features)
        return tokens

Approach 2: Q-Former (BLIP-2)

class QFormer(nn.Module):
    def __init__(self, num_queries=32, d_model=768):
        super().__init__()
        # A fixed set of learnable query vectors
        self.queries = nn.Parameter(torch.randn(num_queries, d_model))
        # Cross-attention stack; layer count here is illustrative
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=12,
                                                   batch_first=True)
        self.cross_attention = nn.TransformerDecoder(decoder_layer, num_layers=6)
    
    def forward(self, image_features):
        batch_size = image_features.size(0)
        queries = self.queries.unsqueeze(0).repeat(batch_size, 1, 1)
        # Queries cross-attend to the image features
        output = self.cross_attention(queries, image_features)
        return output
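
The trade-off between the two approaches: linear projection keeps every patch token (hundreds of visual tokens per image consuming the LLM's context), while the Q-Former compresses the image into a small fixed set of queries (32 in BLIP-2), which is cheaper for the LLM at the cost of some visual detail.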

Multimodal Prompts

def build_multimodal_prompt(image, text_prompt):
    visual_tokens = vision_tokenizer(image)
    text_tokens = text_tokenizer(text_prompt)
    
    # Prefix style: visual tokens are prepended to the text tokens
    # along the sequence dimension
    input_embeds = torch.cat([visual_tokens, text_tokens], dim=0)
    
    return input_embeds

# Usage
prompt = "Describe this image in detail:"
response = multimodal_llm.generate(image=img, prompt=prompt)

4. Training Techniques

Hard Negative Mining

def hard_negative_mining(image_embeds, text_embeds, num_hard=5):
    sim_matrix = image_embeds @ text_embeds.T
    
    batch_size = sim_matrix.size(0)
    hard_negatives = []
    
    for i in range(batch_size):
        # Exclude the positive pair without shifting the column indices
        # (boolean masking would misalign indices past position i)
        sims = sim_matrix[i].clone()
        sims[i] = float('-inf')
        
        # The most similar non-matching texts are the hardest negatives
        hard_neg_idx = sims.topk(num_hard).indices
        hard_negatives.append(hard_neg_idx)
    
    return hard_negatives
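
A hypothetical usage sketch: pair each image with its single hardest negative caption, e.g. to build extra negatives for an ITM-style objective:

# Mine one hardest negative per image
hard_negs = hard_negative_mining(image_embeds, text_embeds, num_hard=1)
# Gather the corresponding text embeddings
neg_text_embeds = torch.stack([text_embeds[idx[0]] for idx in hard_negs])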

Temperature Scaling

class LearnableTemperature(nn.Module):
    def __init__(self, init_value=0.07):
        super().__init__()
        # Parameterize in log space so the scale stays positive during training
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / init_value)))
    
    def forward(self, similarity):
        # Clamp for stability; multiplying by the scale is equivalent
        # to dividing by the temperature
        scale = torch.clamp(self.logit_scale.exp(), 0.01, 100)
        return similarity * scale
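
This is a drop-in replacement for the raw division by self.temperature in the earlier examples; a minimal usage sketch:

temp = LearnableTemperature()
logits = temp(image_embeds @ text_embeds.T)  # scaled similarity matrix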

5. Evaluation and Benchmarks

Major Datasets

Dataset     Task                        Scale
MS-COCO     Captioning / VQA            120K images
Flickr30K   Image-text retrieval        30K images
VQAv2       Visual question answering   1.1M questions
LAION-5B    Pre-training                5B image-text pairs

Evaluation Metrics

def calculate_recall(image_embeds, text_embeds, k=[1, 5, 10]):
    # Assumes image i and text i form the matching pair (diagonal ground truth)
    sim_matrix = image_embeds @ text_embeds.T
    
    recalls = {}
    for K in k:
        # Image-to-text retrieval: is the matching text in the image's top-K?
        topk_indices = sim_matrix.topk(K, dim=1).indices
        correct = (topk_indices == torch.arange(len(sim_matrix)).unsqueeze(1)).any(dim=1)
        recalls[f'I2T_R@{K}'] = correct.float().mean().item()
    
    # Text-to-image recall (T2I_R@K) follows the same pattern on sim_matrix.T
    return recalls

6. Hands-On: Building a Simple CLIP

class SimpleCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        
        from torchvision.models import resnet50, ResNet50_Weights
        from transformers import AutoModel
        
        # Image tower: ResNet-50 without its classification head
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.vision_encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.vision_projection = nn.Linear(2048, 512)
        
        # Text tower: DistilBERT, pooled via the [CLS] token
        self.text_encoder = AutoModel.from_pretrained('distilbert-base-uncased')
        self.text_projection = nn.Linear(768, 512)
        
        self.temperature = nn.Parameter(torch.tensor(0.07))
    
    def encode_image(self, images):
        # flatten(1) keeps the batch dimension, unlike squeeze()
        features = self.vision_encoder(images).flatten(1)
        return F.normalize(self.vision_projection(features), dim=-1)
    
    def encode_text(self, input_ids, attention_mask):
        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0]  # [CLS] token
        return F.normalize(self.text_projection(pooled), dim=-1)

# Training loop (dataloader yields image tensors and tokenized texts)
model = SimpleCLIP()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for images, texts in dataloader:
    image_embeds = model.encode_image(images)
    text_embeds = model.encode_text(texts['input_ids'], texts['attention_mask'])
    
    logits = (image_embeds @ text_embeds.T) / model.temperature
    
    # Symmetric contrastive loss, as in clip_loss above
    labels = torch.arange(len(images), device=logits.device)
    loss = (F.cross_entropy(logits, labels) + 
            F.cross_entropy(logits.T, labels)) / 2
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

7. Summary

The Core of Multimodal AI

  1. Alignment is key: CLIP aligns images and text via contrastive learning
  2. Architecture choice: encoder vs. encoder-decoder vs. unified LLM
  3. Data scale: large-scale pre-training is the foundation

Frontier Directions

  • ImageBind: a unified embedding space across six modalities
  • Any-to-Any generation: arbitrary input and output modalities
  • World models: physical reasoning and 3D understanding

Multimodal large models are emerging as the next commanding height of AI! 🚀

