I have an app that basically ingests documents, stores their embeddings in vector form, then responds to queries that relate to the ingested docs.
I have this retrieval.py function which searches through ingested and selected documents and returns which files have matching vectors.
I'm testing a FastAPI service that queries a database using SQLAlchemy's async execution. However, when I mock the database query, execute().scalars().all() returns an empty list instead of the expected [1, 2, 3].
import numpy as np
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import text
from src.backend.database.config import AsyncSessionLocal
from src.services.ingestion_service.embedding_generator import EmbeddingGenerator
import asyncio
class RetrievalService:
"""
Handles retrieval of similar documents based on query embeddings.
"""
def __init__(self):
self.embedding_generator = EmbeddingGenerator()
async def retrieve_relevant_docs(self, query: str, top_k: int = 5):
"""
Converts the query into an embedding and retrieves the most similar documents asynchronously.
"""
async with AsyncSessionLocal() as db:
async with db.begin():
# ✅ Generate embedding in a separate thread
query_embedding = await self.embedding_generator.generate_embedding(query)
print('query_embedding', query_embedding)
# ✅ Convert NumPy array to PostgreSQL-compatible format
query_embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
# ✅ Fetch selected document IDs asynchronously
selected_ids_result = await db.execute(text("SELECT document_id FROM selected_documents;"))
print('selected_ids_result', selected_ids_result)
selected_ids = (await selected_ids_result.scalars()).all()
# ✅ Ensure selected_ids is not empty to prevent SQL errors
if not selected_ids:
selected_ids = [-1] # Dummy ID to avoid SQL failure
# ✅ Execute the vector similarity search query
search_query = text("""
SELECT document_id FROM embeddings
WHERE document_id = ANY(:selected_ids)
ORDER BY vector <-> CAST(:query_embedding AS vector)
LIMIT :top_k;
""").execution_options(cacheable=False)
results = await db.execute(
search_query,
{
"query_embedding": query_embedding_str, # Pass as string
"top_k": top_k,
"selected_ids": selected_ids,
},
)
# document_ids = (await results.scalars()).all()\
print('debug results', vars(results))
document_ids = list(await results.scalars())
print('document_ids', document_ids)
return document_ids
async def get_document_texts(self, document_ids: list[int]):
"""
Fetches the actual document texts for the given document IDs.
"""
if not document_ids:
return []
async with AsyncSessionLocal() as db:
async with db.begin():
query = text("SELECT content FROM documents WHERE id = ANY(:document_ids);")
results = await db.execute(query, {"document_ids": document_ids})
return (await results.scalars()).all()
I have a simple tests file:
import pytest
from unittest.mock import patch, AsyncMock
from src.services.retrieval_service.retrieval import RetrievalService
from sqlalchemy.ext.asyncio import AsyncSession
import pytest
from unittest.mock import patch, AsyncMock
from src.services.retrieval_service.retrieval import RetrievalService
from sqlalchemy.ext.asyncio import AsyncSession
@pytest.mark.asyncio
async def test_retrieve_relevant_docs_valid_query():
service = RetrievalService()
query = "What is AI?"
top_k = 3
with patch.object(service.embedding_generator, 'generate_embedding', new_callable=AsyncMock) as mock_generate_embedding, \
patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
mock_generate_embedding.return_value = [0.1] * 384
# Mock DB query for selected documents
mock_scalars_selected = AsyncMock()
mock_scalars_selected.scalars.return_value.all.return_value = [1, 2, 3]
mock_execute.side_effect = [mock_scalars_selected, mock_scalars_selected]
# Mock the `execute` method
mock_execute.return_value = mock_scalars_selected
# Call the method
document_ids = await service.retrieve_relevant_docs(query, top_k)
# Assertion
assert document_ids == [1, 2, 3], f"Expected [1, 2, 3] but got {document_ids}"
@pytest.mark.asyncio
async def test_retrieve_relevant_docs_valid_query_1():
service = RetrievalService()
query = "What is AI?"
top_k = 3
with patch.object(service.embedding_generator, 'generate_embedding', new_callable=AsyncMock) as mock_generate_embedding, \
patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
mock_generate_embedding.return_value = [0.1] * 384
# Mock DB query for selected documents
mock_scalars_selected = AsyncMock()
mock_scalars_selected.all = AsyncMock(return_value=[1, 2, 3])
mock_execute.side_effect = [AsyncMock(return_value=AsyncMock(scalars=mock_scalars_selected)), AsyncMock(return_value=AsyncMock(scalars=mock_scalars_selected))]
document_ids = await service.retrieve_relevant_docs(query, top_k)
assert document_ids == [1, 2, 3]
@pytest.mark.asyncio
async def test_retrieve_relevant_docs_no_selected_docs():
service = RetrievalService()
query = "What is AI?"
top_k = 3
with patch.object(service.embedding_generator, 'generate_embedding', new_callable=AsyncMock) as mock_generate_embedding, \
patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
mock_generate_embedding.return_value = [0.1] * 384
# Mock DB returning no selected docs
mock_scalars_selected = AsyncMock()
mock_scalars_selected.all.return_value = []
mock_execute.return_value = mock_scalars_selected
document_ids = await service.retrieve_relevant_docs(query, top_k)
assert document_ids == []
@pytest.mark.asyncio
async def test_retrieve_relevant_docs_empty_query():
service = RetrievalService()
query = ""
top_k = 3
with patch.object(service.embedding_generator, 'generate_embedding', new_callable=AsyncMock) as mock_generate_embedding, \
patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
mock_generate_embedding.return_value = [0.1] * 384
# Mock DB returning no documents
mock_scalars_selected = AsyncMock()
mock_scalars_selected.all.return_value = []
mock_execute.return_value = mock_scalars_selected
document_ids = await service.retrieve_relevant_docs(query, top_k)
assert document_ids == []
@pytest.mark.asyncio
async def test_get_document_texts_valid_ids():
service = RetrievalService()
document_ids = [1, 2, 3]
with patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
# Mock query result
mock_scalars = AsyncMock()
mock_scalars.all.return_value = ["Document 1 text", "Document 2 text", "Document 3 text"]
mock_execute.return_value = mock_scalars
document_texts = await service.get_document_texts(document_ids)
assert document_texts == ["Document 1 text", "Document 2 text", "Document 3 text"]
@pytest.mark.asyncio
async def test_get_document_texts_no_ids():
service = RetrievalService()
document_ids = []
with patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
document_texts = await service.get_document_texts(document_ids)
assert document_texts == []
I have added so much debugging information, but i do not understand why when I am mocking the retreivalservice to have the side effect [1,2,3] and just yield that value after going through the service. I keep getting error that shows that my mock_execute.side_effect
is not working at all.
These are the logs that get printed:
selected_ids_result debug results {'_mock_return_value': sentinel.DEFAULT, '_mock_parent': None, '_mock_name': None, '_mock_new_name': '()', '_mock_new_parent': , '_mock_sealed': False, '_spec_class': None, '_spec_set': None, '_spec_signature': None, '_mock_methods': None, '_spec_asyncs': [], '_mock_children': {'scalars': , 'str': }, '_mock_wraps': None, '_mock_delegate': None, '_mock_called': False, '_mock_call_args': None, '_mock_call_count': 0, '_mock_call_args_list': [], '_mock_mock_calls': [call.str(), call.scalars(), call.scalars().all()], 'method_calls': [call.scalars()], '_mock_unsafe': False, '_mock_side_effect': None, '_is_coroutine': <object object at 0x0000029C06E16A30>, '_mock_await_count': 0, '_mock_await_args': None, '_mock_await_args_list': [], 'code': , 'str': } document_ids []
and these errors:
short test summary info ==================================== FAILED src/tests/unit/test_retrieve_docs.py::test_retrieve_relevant_docs_valid_query - AssertionError: Expected [1, 2, 3] but got [] FAILED src/tests/unit/test_retrieve_docs.py::test_retrieve_relevant_docs_valid_query_1 - AssertionError: assert [] == [1, 2, 3] FAILED src/tests/unit/test_retrieve_docs.py::test_get_document_texts_valid_ids - AssertionError: assert <coroutine object AsyncMockMixin._execute_mock_call at 0x000002373EA3B...
Observed Behavior: results.scalars().all() unexpectedly returns [], even though I attempted to mock it. Debugging vars(results) shows _mock_side_effect = None, suggesting the mock isn't working as expected.
Expected Behavior: document_ids should contain [1, 2, 3], matching the mocked return value.
What I've Tried: Explicitly setting scalars().all().return_value = [1, 2, 3]. Checking vars(results) for missing attributes. Ensuring mock_execute.side_effect is properly assigned. Calling await session.execute(...).scalars().all() instead of wrapping it in list(). What is the correct way to mock SQLAlchemy's async execution (session.execute().scalars().all()) in a FastAPI test using AsyncMock? Or Can someone point why my is not behaving as I expect it to?
I feel if i fix one test, all my tests should get fixed the same way. I am not new to python, but very new to sqlalchemy