Spaces:
Runtime error
Runtime error
| from qdrant_client import QdrantClient | |
| from qdrant_client.http.models import ScoredPoint | |
| from embedding import Embedding | |
| from model.document import Document | |
| from model.record import Record | |
| from model.user import User | |
| from qdrant_client.http import models | |
| import uuid | |
| import tqdm | |
class Index:
    """Interface for a per-user document index.

    Every method here is a no-op stub that returns ``None``; concrete
    backends (for example ``QDrantVectorStore`` in this file) override them.
    """

    # Short identifier of the backend; subclasses assign a concrete value.
    type: str

    def load_or_update_document(
            self,
            user: User,
            document: Document,
            progress: tqdm.tqdm = None):
        """Store ``document`` for ``user``. Stub: does nothing."""

    def remove_document(self, user: User, document: Document):
        """Remove ``document`` for ``user``. Stub: does nothing."""

    def query_index(
            self,
            user: User,
            query: str,
            top_k: int = 10,
            threshold: float = 0.5) -> list[Record]:
        """Query everything indexed for ``user``. Stub: returns ``None``."""

    def query_document(
            self,
            user: User,
            document: Document,
            query: str,
            top_k: int = 10,
            threshold: float = 0.5) -> list[Record]:
        """Query one ``document`` of ``user``. Stub: returns ``None``."""

    def contains(self, user: User, document: Document) -> bool:
        """Tell whether ``document`` is indexed. Stub: returns ``None``."""
class QDrantVectorStore(Index):
    """``Index`` implementation backed by a single Qdrant collection.

    Every record is stored as one point whose payload holds ``content``,
    ``meta_data``, ``document_id``, ``group_id`` and ``timestamp``.
    ``group_id`` is the owning user's ``user_name`` and is part of the filter
    of every count, delete and search issued by this class.
    """

    _client: QdrantClient
    _embedding: Embedding
    collection_name: str
    # Number of records sent per upsert call.
    batch_size: int = 10
    type: str = 'qdrant'

    def __init__(
            self,
            client: QdrantClient,
            embedding: Embedding,
            collection_name: str):
        """Keep references to the client, the embedding and the collection name.

        No request is sent to Qdrant here; the collection is created lazily
        by ``load_or_update_document``.
        """
        self._embedding = embedding
        self.collection_name = collection_name
        self._client = client

    @staticmethod
    def _build_filter(user: User, document: Document = None) -> models.Filter:
        """Build the payload filter shared by count, delete and search.

        The filter always requires ``group_id == user.user_name``. When
        ``document`` is given it additionally requires
        ``document_id == document.name`` (listed first, as before).
        """
        conditions = []
        if document is not None:
            # NOTE(review): points are written with ``record.document_id``
            # but filtered by ``document.name``; presumably these are equal
            # -- confirm against Document.load_records.
            conditions.append(
                models.FieldCondition(
                    key="document_id",
                    match=models.MatchValue(value=document.name),
                )
            )
        conditions.append(
            models.FieldCondition(
                key="group_id",
                match=models.MatchValue(value=user.user_name),
            )
        )
        return models.Filter(must=conditions)

    def _response_to_records(self, response: list[ScoredPoint]) -> list[Record]:
        """Convert scored points into a list of ``Record`` objects.

        Returns a real list (not a generator) so the result matches the
        annotation and the ``[]`` returned by the query methods when the
        collection is missing; callers can take ``len()``, index it, or
        iterate it more than once.
        """
        # NOTE(review): the searches in this class do not ask for vectors,
        # so ``point.vector`` is presumably ``None`` here -- confirm against
        # the qdrant-client defaults if the embedding is needed downstream.
        return [
            Record(
                embedding=point.vector,
                meta_data=point.payload['meta_data'],
                content=point.payload['content'],
                document_id=point.payload['document_id'],
                timestamp=point.payload['timestamp'],
            )
            for point in response
        ]

    def _search(
            self,
            query: str,
            query_filter: models.Filter,
            top_k: int,
            threshold: float) -> list[Record]:
        """Embed ``query``, search the collection and return matching records.

        Returns ``[]`` without embedding or searching when the collection
        does not exist.
        """
        if not self.if_collection_exists():
            return []
        response = self._client.search(
            collection_name=self.collection_name,
            query_vector=self._embedding.generate_embedding(query),
            limit=top_k,
            query_filter=query_filter,
            score_threshold=threshold,
        )
        return self._response_to_records(response)

    def create_collection(self):
        """(Re)create the collection with cosine distance.

        The vector size is taken from ``self._embedding.vector_size``.
        """
        # NOTE(review): this calls ``recreate_collection``; going by the
        # client API name, an existing collection of the same name is
        # replaced -- confirm before calling this on a populated store.
        self._client.recreate_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(
                size=self._embedding.vector_size,
                distance=models.Distance.COSINE,
            ),
        )

    def if_collection_exists(self) -> bool:
        """Return ``True`` when ``get_collection`` succeeds, else ``False``."""
        try:
            self._client.get_collection(self.collection_name)
        except Exception:
            # Deliberately broad: any failure is reported as "missing".
            # NOTE(review): a transient error therefore also yields False,
            # which makes ``create_collection_if_not_exists`` call
            # ``create_collection`` -- consider narrowing once the concrete
            # "not found" exception of the client in use is known.
            return False
        return True

    def create_collection_if_not_exists(self):
        """Create the collection unless ``if_collection_exists`` is true."""
        if not self.if_collection_exists():
            self.create_collection()

    def load_or_update_document(self, user: User, document: Document, progress: tqdm.tqdm = None):
        """Index all records of ``document`` for ``user``, replacing old ones.

        Any points already stored for this user/document pair are deleted
        first, then the records from ``document.load_records()`` are
        upserted in batches of ``self.batch_size`` under fresh UUID ids.

        ``progress``, when given, is *called* with the range of batch start
        offsets and the result is iterated instead, so it must be a callable
        that wraps an iterable (e.g. the ``tqdm.tqdm`` class itself).
        """
        self.create_collection_if_not_exists()
        if self.contains(user, document):
            self.remove_document(user, document)

        group_id = user.user_name
        # Materialise so the total is known and batches can be sliced.
        records = list(document.load_records())

        batch_starts = range(0, len(records), self.batch_size)
        if progress is not None:
            batch_starts = progress(batch_starts)

        for start in batch_starts:
            batch = records[start:start + self.batch_size]
            payloads = [
                {
                    'content': record.content,
                    'meta_data': record.meta_data,
                    'document_id': record.document_id,
                    'group_id': group_id,
                    'timestamp': record.timestamp,
                }
                for record in batch
            ]
            self._client.upsert(
                collection_name=self.collection_name,
                points=models.Batch(
                    payloads=payloads,
                    ids=[str(uuid.uuid4()) for _ in batch],
                    vectors=[record.embedding for record in batch],
                ),
            )

    def remove_document(self, user: User, document: Document):
        """Delete every point stored for ``document`` under ``user``.

        Does nothing when the collection does not exist.
        """
        if not self.if_collection_exists():
            return
        self._client.delete(
            collection_name=self.collection_name,
            points_selector=models.FilterSelector(
                filter=self._build_filter(user, document),
            ),
        )

    def contains(self, user: User, document: Document) -> bool:
        """Return whether at least one point exists for ``user``/``document``.

        Returns ``False`` when the collection does not exist, consistent
        with the other read paths of this class, instead of letting the
        count request fail.
        """
        if not self.if_collection_exists():
            return False
        result = self._client.count(
            collection_name=self.collection_name,
            count_filter=self._build_filter(user, document),
            exact=True,
        )
        return result.count > 0

    def query_index(self, user: User, query: str, top_k: int = 10, threshold: float = 0.5) -> list[Record]:
        """Search all points of ``user`` for ``query``.

        ``top_k`` is passed as the search ``limit`` and ``threshold`` as the
        ``score_threshold``. Returns ``[]`` when the collection is missing.
        """
        return self._search(query, self._build_filter(user), top_k, threshold)

    def query_document(self, user: User, document: Document, query: str, top_k: int = 10, threshold: float = 0.5) -> list[Record]:
        """Search only the points of ``document`` owned by ``user``.

        ``top_k`` is passed as the search ``limit`` and ``threshold`` as the
        ``score_threshold``. Returns ``[]`` when the collection is missing.
        """
        return self._search(
            query, self._build_filter(user, document), top_k, threshold)