
| import json from pycocotools.coco import COCO from pathlib import Path
def process_metadata():
    """Flatten the COCO 2017 training captions into a JSON metadata file.

    Reads ``annotations/captions_train2017.json`` below the dataset root and
    writes ``processed/metadata.json`` (creating the directory if needed).
    The output is a list with one record per caption annotation, each holding
    the annotation ``id``, the ``image_path`` of its image and the ``caption``.
    """
    dataset_dir = Path("/auto-tmp/coco2017")
    captions_file = dataset_dir / "annotations/captions_train2017.json"
    destination = dataset_dir / "processed/metadata.json"
    destination.parent.mkdir(parents=True, exist_ok=True)

    coco = COCO(captions_file)

    records = []
    for image_id in coco.getImgIds():
        annotation_ids = coco.getAnnIds(imgIds=image_id)
        # The image file name is derived from the id, zero-padded to 12 digits.
        image_file = str(dataset_dir / f"train2017/{image_id:012d}.jpg")
        records.extend(
            {
                "id": entry["id"],
                "image_path": image_file,
                "caption": entry["caption"],
            }
            for entry in coco.loadAnns(annotation_ids)
        )

    with open(destination, "w") as handle:
        json.dump(records, handle, indent=2)
# Runs at import time: writes processed/metadata.json from the COCO captions.
process_metadata()
import torch import clip import numpy as np from tqdm import tqdm from PIL import Image
class CLIPFeatureExtractor:
    """Compute L2-normalised CLIP (RN50x4) features for images and captions.

    Both extraction methods read ``metadata.json`` from the processed
    directory and save their result as a float32 ``.npy`` array next to it.
    """

    def __init__(self):
        """Load the RN50x4 CLIP model on CUDA when available, else on CPU."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load("RN50x4", device=self.device)
        self.model.eval()

    @staticmethod
    def _unique_by_image_path(metadata):
        """Return one metadata record per distinct ``image_path``.

        Records keep the order in which each path first appears; when a path
        occurs several times, the last record seen for it is the one kept.
        """
        return list({item["image_path"]: item for item in metadata}.values())

    def _batch_processor(self, data, batch_size, process_fn):
        """Apply ``process_fn`` to consecutive slices of ``data`` and join them.

        Args:
            data: Sliceable sequence of items to process.
            batch_size: Maximum number of items handed to ``process_fn``.
            process_fn: Callable taking one slice and returning a NumPy array.

        Returns:
            The per-batch arrays concatenated along the first axis.

        Raises:
            ValueError: If ``data`` is empty (there is nothing to concatenate).
        """
        # Fail early with a clear message instead of the opaque error that
        # np.concatenate raises for an empty list of arrays.
        if len(data) == 0:
            raise ValueError("no items to process: input data is empty")
        results = [
            process_fn(data[start:start + batch_size])
            for start in tqdm(range(0, len(data), batch_size))
        ]
        return np.concatenate(results)

    def extract_image_features(self, batch_size=512):
        """Encode every distinct image in the metadata and save the features.

        Rows of the saved ``image_features.npy`` follow the order of the
        de-duplicated image records.

        Args:
            batch_size: Number of images encoded per forward pass.
        """
        with open("/auto-tmp/coco2017/processed/metadata.json") as f:
            metadata = json.load(f)

        unique_images = self._unique_by_image_path(metadata)
        # Only enable CUDA autocast when the model actually runs on CUDA.
        use_amp = self.device == "cuda"

        def process_batch(batch):
            images = []
            for item in batch:
                # The context manager closes the file as soon as the pixels
                # have been converted.
                with Image.open(item["image_path"]) as img:
                    images.append(self.preprocess(img.convert("RGB")))

            with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_amp):
                tensor = torch.stack(images).to(self.device)
                features = self.model.encode_image(tensor)
                features /= features.norm(dim=-1, keepdim=True)
            return features.cpu().numpy().astype("float32")

        features = self._batch_processor(unique_images, batch_size, process_batch)
        np.save("/auto-tmp/coco2017/processed/image_features.npy", features)

    def extract_text_features(self, batch_size=2048):
        """Encode every caption in the metadata and save the features.

        Rows of the saved ``text_features.npy`` follow the order of the
        records in ``metadata.json``.

        Args:
            batch_size: Number of captions encoded per forward pass.
        """
        with open("/auto-tmp/coco2017/processed/metadata.json") as f:
            metadata = json.load(f)

        def process_batch(batch):
            texts = [item["caption"] for item in batch]
            with torch.no_grad():
                # truncate=True is passed so over-long captions are truncated
                # by the tokenizer.
                inputs = clip.tokenize(texts, truncate=True).to(self.device)
                features = self.model.encode_text(inputs)
                features /= features.norm(dim=-1, keepdim=True)
            return features.cpu().numpy().astype("float32")

        features = self._batch_processor(metadata, batch_size, process_batch)
        np.save("/auto-tmp/coco2017/processed/text_features.npy", features)
# Runs at import time: loads CLIP and writes the image and text feature files.
extractor = CLIPFeatureExtractor()
extractor.extract_image_features()
extractor.extract_text_features()
import faiss import numpy as np
class VectorIndexer:
    """Build a FAISS inner-product index over the saved caption features."""

    def __init__(self):
        """Allocate the FAISS GPU resources used by ``build_index``."""
        self.res = faiss.StandardGpuResources()

    def build_index(self):
        """Index ``text_features.npy`` on GPU 0 and write the index to disk.

        The features are passed through ``faiss.normalize_L2``, added to a
        flat inner-product index in slices of 50000 rows, and the result is
        copied back to the CPU and written to ``faiss_index.index``.
        """
        embeddings = np.load("/auto-tmp/coco2017/processed/text_features.npy")
        # Called for its effect on the array; the return value is not used.
        faiss.normalize_L2(embeddings)

        flat_index = faiss.IndexFlatIP(embeddings.shape[1])
        device_index = faiss.index_cpu_to_gpu(self.res, 0, flat_index)

        step = 50000
        total = len(embeddings)
        start = 0
        while start < total:
            device_index.add(embeddings[start:start + step])
            start += step

        faiss.write_index(
            faiss.index_gpu_to_cpu(device_index),
            "/auto-tmp/coco2017/processed/faiss_index.index",
        )
# Runs at import time: builds the FAISS index from the saved text features.
indexer = VectorIndexer()
indexer.build_index()
class ImageCaptionRetriever:
    """Look up captions whose CLIP text features best match a query image.

    The FAISS index is searched with the image's CLIP feature; each returned
    row number is used as a position in the ``metadata.json`` list.
    """

    def __init__(self):
        """Load the CLIP model, the FAISS index and the caption metadata."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load("RN50x4", device=self.device)
        self.index = faiss.read_index("/auto-tmp/coco2017/processed/faiss_index.index")
        with open("/auto-tmp/coco2017/processed/metadata.json") as f:
            self.metadata = json.load(f)
        # Prepared on the first retrieve() call and reused afterwards.
        self._search_index = None

    def _get_search_index(self):
        """Return the index used for searching, creating it on first use.

        The CPU index is copied to GPU 0 once when CUDA and FAISS GPU support
        are both available; otherwise the CPU index is searched directly.
        The copy is cached, so later changes to ``self.index`` are not
        reflected in it.
        """
        cached = getattr(self, "_search_index", None)
        if cached is None:
            if self.device == "cuda" and hasattr(faiss, "StandardGpuResources"):
                # Keep the resources object on the instance for as long as
                # the GPU index is in use.
                self._gpu_resources = faiss.StandardGpuResources()
                cached = faiss.index_cpu_to_gpu(self._gpu_resources, 0, self.index)
            else:
                cached = self.index
            self._search_index = cached
        return cached

    def _captions_for(self, indices):
        """Map search result row numbers to captions, skipping ``-1`` fillers.

        FAISS pads results with -1 when fewer than ``top_k`` neighbours
        exist; used as a list index that would select the last caption.
        """
        return [self.metadata[idx]["caption"] for idx in indices if idx >= 0]

    def _format_output(self, captions):
        """Format the captions into the prompt template, one ``<caption>`` per line."""
        # NOTE(review): template text is reproduced as it appears in this
        # file; confirm whether line breaks were intended inside it.
        base_template = "I am an intelligent image captioning bot. Similar images have the following captions: {captions}"
        caption_list = "\n".join(f"<{caption}>" for caption in captions)
        return base_template.format(captions=caption_list)

    def retrieve(self, image_path, top_k=5):
        """Return the formatted prompt with the captions closest to an image.

        Args:
            image_path: Path of the query image.
            top_k: Number of nearest captions requested from the index.

        Returns:
            The template string filled with the retrieved captions.
        """
        with Image.open(image_path) as img:
            tensor = self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)

        # Only enable CUDA autocast when the model actually runs on CUDA.
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=self.device == "cuda"):
            features = self.model.encode_image(tensor)
            features /= features.norm(dim=-1, keepdim=True)

        query = features.cpu().numpy().astype("float32")
        _, indices = self._get_search_index().search(query, top_k)
        return self._format_output(self._captions_for(indices[0]))
# Example query; "/path/to/query_image.jpg" is a placeholder path.
retriever = ImageCaptionRetriever()
output = retriever.retrieve("/path/to/query_image.jpg")
print(output)
|