demo-code-search/code_search/index/upload_code.py

94 lines
2.6 KiB
Python

from pathlib import Path
from tqdm import tqdm
from qdrant_client.http import models as rest
import qdrant_client
import numpy as np
import json
from code_search.config import QDRANT_URL, QDRANT_API_KEY, DATA_DIR, QDRANT_CODE_COLLECTION_NAME
from code_search.model.encoder import UniXcoderEmbeddingsProvider
# Snippet fields that may carry the source text to embed, in lookup priority
# order: the first field whose value is not None is the one that gets encoded.
code_keys = ["code_snippet", "body", "signature", "name"]
def _pick_code_body(snippet):
    """Return the value of the first ``code_keys`` field that is not None, else None."""
    for code_key in code_keys:
        body = snippet.get(code_key)
        if body is not None:
            return body
    return None


def _load_encodable_snippets(input_file):
    """Parse the JSONL file, keeping only records that have a non-empty code body.

    Returns a list of ``(payload, body)`` pairs in file order, where ``payload``
    is the parsed JSON object of one line and ``body`` is the code selected from
    it. Records without a usable body are dropped here so that the embeddings
    and the payloads are always derived from the very same set of records.
    """
    snippets = []
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(input_file, "r", encoding="utf-8") as fp:
        for line in tqdm(fp):
            line_dict = json.loads(line)
            body = _pick_code_body(line_dict)
            if body is None or len(body) == 0:
                continue
            snippets.append((line_dict, body))
    return snippets


def encode_and_upload():
    """Embed the code snippets and (re)create the Qdrant collection holding them.

    Reads ``qdrant_snippets.jsonl`` from ``DATA_DIR``, computes one embedding per
    snippet that has a code body (or reuses the ``<collection name>.npy`` cache
    in ``DATA_DIR`` when it exists), recreates the collection named
    ``QDRANT_CODE_COLLECTION_NAME`` and upserts all vectors together with their
    JSON payloads. Point ids are the 0-based positions of the kept snippets.

    Raises:
        RuntimeError: if the input file is missing, if it contains no snippet
            with a code body, or if a cached embeddings file does not have one
            vector per encodable snippet.
    """
    client = qdrant_client.QdrantClient(
        QDRANT_URL,
        api_key=QDRANT_API_KEY,
        prefer_grpc=True,
    )
    collection_name = QDRANT_CODE_COLLECTION_NAME

    # Join DATA_DIR exactly once; joining it a second time duplicates the
    # directory prefix whenever DATA_DIR is a relative path.
    data_dir = Path(DATA_DIR)
    input_file = data_dir / "qdrant_snippets.jsonl"
    output_file = data_dir / f"{collection_name}.npy"

    if not input_file.exists():
        raise RuntimeError(f"File {input_file} does not exist. Skipping")

    snippets = _load_encodable_snippets(input_file)
    if not snippets:
        raise RuntimeError(
            f"File {input_file} contains no snippets with a code body. Nothing to upload."
        )
    # Payloads come from the filtered records, so payloads[i] always describes
    # the snippet that embeddings[i] was computed from.
    payloads = [payload for payload, _ in snippets]

    if output_file.exists():
        print(f"File {output_file} already exists. Skipping encoding.")
        embeddings = np.load(str(output_file)).tolist()
    else:
        print(f"Preparing the output for {output_file}")
        # The encoder is only instantiated when there is something to encode.
        encoder = UniXcoderEmbeddingsProvider()
        embeddings = [
            encoder.embed_code(body, payload.get("docstring"))
            for payload, body in tqdm(snippets)
        ]
        np.save(str(output_file), np.array(embeddings))

    # A cache produced from a different version of the input file would pair
    # vectors with the wrong payloads; refuse to upload in that case.
    if len(embeddings) != len(payloads):
        raise RuntimeError(
            f"Found {len(embeddings)} embeddings in {output_file} but {len(payloads)} "
            f"encodable snippets in {input_file}. Remove the cached file and run again."
        )

    vector_size = len(embeddings[0])
    print(f"Embeddings shape: ({len(embeddings)}, {vector_size})")

    print(f"Recreating the collection {collection_name}")
    client.recreate_collection(
        collection_name=collection_name,
        vectors_config=rest.VectorParams(
            size=vector_size,
            distance=rest.Distance.COSINE,
        ),
    )

    print(f"Storing data in the collection {collection_name}")
    response = client.upsert(
        collection_name=collection_name,
        points=rest.Batch(
            ids=list(range(len(embeddings))),
            vectors=embeddings,
            payloads=payloads,
        ),
    )
    print(response)
# Script entry point: run the full encode-and-upload pipeline when this module
# is executed directly rather than imported.
if __name__ == "__main__":
    encode_and_upload()