import json
from pathlib import Path

from pycocotools.coco import COCO
def process_metadata():
    """Flatten the COCO caption annotations into one record per (image, caption) pair."""
    data_root = Path("/auto-tmp/coco2017")
    ann_path = data_root / "annotations/captions_train2017.json"
    output_path = data_root / "processed/metadata.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)

    coco = COCO(ann_path)
    metadata = []
    for img_id in coco.getImgIds():
        anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        # COCO 2017 image files are named by their zero-padded 12-digit image id.
        img_path = data_root / f"train2017/{img_id:012d}.jpg"
        for ann in anns:
            metadata.append({
                "id": ann["id"],
                "image_path": str(img_path),
                "caption": ann["caption"],
            })

    with open(output_path, "w") as f:
        json.dump(metadata, f, indent=2)
process_metadata()
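As a quick sanity check (a minimal sketch, assuming the paths above), the flattened file should contain one record per caption, i.e. several records per image:

import json

with open("/auto-tmp/coco2017/processed/metadata.json") as f:
    metadata = json.load(f)

# Each COCO image carries roughly five captions, so the record count
# should be about 5x the number of train2017 images.
print(len(metadata))
print(metadata[0])  # {'id': ..., 'image_path': ..., 'caption': ...}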
import json

import clip
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
class CLIPFeatureExtractor:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load("RN50x4", device=self.device)
        self.model.eval()

    def _batch_processor(self, data, batch_size, process_fn):
        """Generic batching helper: apply process_fn to slices of data and stack the results."""
        results = []
        for i in tqdm(range(0, len(data), batch_size)):
            batch = data[i:i + batch_size]
            results.append(process_fn(batch))
        return np.concatenate(results)

    def extract_image_features(self):
        """Encode each unique image once with the CLIP image tower."""
        with open("/auto-tmp/coco2017/processed/metadata.json") as f:
            metadata = json.load(f)
        # Dedupe by image path: several caption records share one image.
        unique_images = list({item["image_path"]: item for item in metadata}.values())

        def process_batch(batch):
            images = []
            for item in batch:
                img = Image.open(item["image_path"]).convert("RGB")
                images.append(self.preprocess(img))
            with torch.no_grad(), torch.cuda.amp.autocast():
                tensor = torch.stack(images).to(self.device)
                features = self.model.encode_image(tensor)
                # Unit-normalize so that inner product equals cosine similarity.
                features /= features.norm(dim=-1, keepdim=True)
            return features.cpu().numpy().astype("float32")

        batch_size = 512
        features = self._batch_processor(unique_images, batch_size, process_batch)
        np.save("/auto-tmp/coco2017/processed/image_features.npy", features)

    def extract_text_features(self):
        """Encode every caption with the CLIP text tower."""
        with open("/auto-tmp/coco2017/processed/metadata.json") as f:
            metadata = json.load(f)

        def process_batch(batch):
            texts = [item["caption"] for item in batch]
            with torch.no_grad():
                inputs = clip.tokenize(texts, truncate=True).to(self.device)
                features = self.model.encode_text(inputs)
                features /= features.norm(dim=-1, keepdim=True)
            return features.cpu().numpy().astype("float32")

        batch_size = 2048
        features = self._batch_processor(metadata, batch_size, process_batch)
        np.save("/auto-tmp/coco2017/processed/text_features.npy", features)
extractor = CLIPFeatureExtractor() extractor.extract_image_features() extractor.extract_text_features()
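Before building the index, it is worth checking that the saved arrays line up with the metadata. A minimal sketch, assuming the same paths as above; CLIP RN50x4 embeds into 640 dimensions:

import json
import numpy as np

with open("/auto-tmp/coco2017/processed/metadata.json") as f:
    metadata = json.load(f)

img_feats = np.load("/auto-tmp/coco2017/processed/image_features.npy")
txt_feats = np.load("/auto-tmp/coco2017/processed/text_features.npy")

# One row per unique image, one row per caption, 640 dims each for RN50x4.
print(img_feats.shape, txt_feats.shape)
assert txt_feats.shape[0] == len(metadata)
assert img_feats.shape[1] == txt_feats.shape[1] == 640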
import faiss
import numpy as np
class VectorIndexer:
    def __init__(self):
        self.res = faiss.StandardGpuResources()

    def build_index(self):
        text_features = np.load("/auto-tmp/coco2017/processed/text_features.npy")
        # Features were normalized at extraction time; normalizing again is a cheap safeguard.
        faiss.normalize_L2(text_features)

        # Exact inner-product index; on unit vectors this ranks by cosine similarity.
        dim = text_features.shape[1]
        index = faiss.IndexFlatIP(dim)
        gpu_index = faiss.index_cpu_to_gpu(self.res, 0, index)

        # Add in chunks to keep host-to-device transfers bounded.
        chunk_size = 50000
        for i in range(0, len(text_features), chunk_size):
            gpu_index.add(text_features[i:i + chunk_size])

        # Persist a CPU copy; a GPU index cannot be written to disk directly.
        cpu_index = faiss.index_gpu_to_cpu(gpu_index)
        faiss.write_index(cpu_index, "/auto-tmp/coco2017/processed/faiss_index.index")
indexer = VectorIndexer() indexer.build_index()
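Reading the index back from disk is a cheap way to verify the build. A minimal sketch, assuming the same file paths; with unit-normalized vectors, IndexFlatIP scores are cosine similarities, so a vector's nearest neighbor should score roughly 1.0:

import faiss
import numpy as np

index = faiss.read_index("/auto-tmp/coco2017/processed/faiss_index.index")
text_features = np.load("/auto-tmp/coco2017/processed/text_features.npy")

print(index.ntotal, index.d)  # caption count and embedding dimensionality
assert index.ntotal == len(text_features)

# Querying with the first caption vector should return a score of ~1.0
# (itself, or an identical duplicate caption).
scores, ids = index.search(text_features[:1], 1)
print(ids[0][0], scores[0][0])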
class ImageCaptionRetriever:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load("RN50x4", device=self.device)
        # Move the index to the GPU once here, rather than on every query.
        cpu_index = faiss.read_index("/auto-tmp/coco2017/processed/faiss_index.index")
        self.index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, cpu_index)
        with open("/auto-tmp/coco2017/processed/metadata.json") as f:
            self.metadata = json.load(f)

    def _format_output(self, captions):
        """Format the retrieved captions with the prompt template."""
        base_template = """I am an intelligent image captioning bot. Similar images have the following captions:
{captions}"""
        caption_list = "\n".join([f"<{caption}>" for caption in captions])
        return base_template.format(captions=caption_list)

    def retrieve(self, image_path, top_k=5):
        image = Image.open(image_path).convert("RGB")
        tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            features = self.model.encode_image(tensor)
            features /= features.norm(dim=-1, keepdim=True)
        # Search the caption index; row indices map back to metadata entries.
        _, indices = self.index.search(features.cpu().numpy().astype("float32"), top_k)
        results = [self.metadata[idx]["caption"] for idx in indices[0]]
        return self._format_output(results)
retriever = ImageCaptionRetriever() output = retriever.retrieve("/path/to/query_image.jpg") print(output)