python - AsyncMock Returns Empty List When Mocking SQLAlchemy Async Execution in FastAPI - Stack Overflow

I have an app that basically ingests documents, stores their embeddings in vector form, then responds to queries that relate to the ingested docs.

My retrieval.py module searches the ingested documents that are currently selected and returns the IDs of those whose vectors match the query.

I'm testing a FastAPI service that queries a database using SQLAlchemy's async execution. However, when I mock the database query, execute().scalars().all() returns an empty list instead of the expected [1, 2, 3].

import numpy as np
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import text
from src.backend.database.config import AsyncSessionLocal
from src.services.ingestion_service.embedding_generator import EmbeddingGenerator
import asyncio


class RetrievalService:
    """
    Handles retrieval of similar documents based on query embeddings.
    """

    def __init__(self):
        self.embedding_generator = EmbeddingGenerator()

    async def retrieve_relevant_docs(self, query: str, top_k: int = 5):
        """
        Converts the query into an embedding and retrieves the most similar documents asynchronously.
        """
        async with AsyncSessionLocal() as db:
            async with db.begin():
                # ✅ Generate embedding in a separate thread
                query_embedding = await self.embedding_generator.generate_embedding(query)

                
                print('query_embedding', query_embedding)

                # ✅ Convert NumPy array to PostgreSQL-compatible format
                query_embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"

                # ✅ Fetch selected document IDs asynchronously
                selected_ids_result = await db.execute(text("SELECT document_id FROM selected_documents;"))
                print('selected_ids_result', selected_ids_result)
                selected_ids = (await selected_ids_result.scalars()).all()

                # ✅ Ensure selected_ids is not empty to prevent SQL errors
                if not selected_ids:
                    selected_ids = [-1]  # Dummy ID to avoid SQL failure

                # ✅ Execute the vector similarity search query
                search_query = text("""
                    SELECT document_id FROM embeddings
                    WHERE document_id = ANY(:selected_ids)
                    ORDER BY vector <-> CAST(:query_embedding AS vector)
                    LIMIT :top_k;
                """).execution_options(cacheable=False)

                results = await db.execute(
                    search_query,
                    {
                        "query_embedding": query_embedding_str,  # Pass as string
                        "top_k": top_k,
                        "selected_ids": selected_ids,
                    },
                )

                # document_ids = (await results.scalars()).all()\
                print('debug results', vars(results))
                document_ids = list(await results.scalars())

                print('document_ids', document_ids)
                return document_ids

    async def get_document_texts(self, document_ids: list[int]):
        """
        Fetches the actual document texts for the given document IDs.
        """
        if not document_ids:
            return []

        async with AsyncSessionLocal() as db:
            async with db.begin():
                query = text("SELECT content FROM documents WHERE id = ANY(:document_ids);")
                results = await db.execute(query, {"document_ids": document_ids})
                return (await results.scalars()).all()

I have a simple test file:

import pytest
from unittest.mock import patch, AsyncMock
from src.services.retrieval_service.retrieval import RetrievalService
from sqlalchemy.ext.asyncio import AsyncSession

@pytest.mark.asyncio
async def test_retrieve_relevant_docs_valid_query():
    service = RetrievalService()
    query = "What is AI?"
    top_k = 3

    with patch.object(service.embedding_generator, 'generate_embedding', new_callable=AsyncMock) as mock_generate_embedding, \
         patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
        
        mock_generate_embedding.return_value = [0.1] * 384

        # Mock DB query for selected documents
        mock_scalars_selected = AsyncMock()
        mock_scalars_selected.scalars.return_value.all.return_value = [1, 2, 3]
        
        mock_execute.side_effect = [mock_scalars_selected, mock_scalars_selected]
        # Mock the `execute` method
        mock_execute.return_value = mock_scalars_selected

        # Call the method
        document_ids = await service.retrieve_relevant_docs(query, top_k)

        # Assertion
        assert document_ids == [1, 2, 3], f"Expected [1, 2, 3] but got {document_ids}"


@pytest.mark.asyncio
async def test_retrieve_relevant_docs_valid_query_1():
    service = RetrievalService()
    query = "What is AI?"
    top_k = 3

    with patch.object(service.embedding_generator, 'generate_embedding', new_callable=AsyncMock) as mock_generate_embedding, \
         patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
        
        mock_generate_embedding.return_value = [0.1] * 384

        # Mock DB query for selected documents
        mock_scalars_selected = AsyncMock()
        mock_scalars_selected.all = AsyncMock(return_value=[1, 2, 3])
        mock_execute.side_effect = [AsyncMock(return_value=AsyncMock(scalars=mock_scalars_selected)), AsyncMock(return_value=AsyncMock(scalars=mock_scalars_selected))]

        document_ids = await service.retrieve_relevant_docs(query, top_k)
        assert document_ids == [1, 2, 3]


@pytest.mark.asyncio
async def test_retrieve_relevant_docs_no_selected_docs():
    service = RetrievalService()
    query = "What is AI?"
    top_k = 3

    with patch.object(service.embedding_generator, 'generate_embedding', new_callable=AsyncMock) as mock_generate_embedding, \
         patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
        
        mock_generate_embedding.return_value = [0.1] * 384

        # Mock DB returning no selected docs
        mock_scalars_selected = AsyncMock()
        mock_scalars_selected.all.return_value = []
        mock_execute.return_value = mock_scalars_selected

        document_ids = await service.retrieve_relevant_docs(query, top_k)
        assert document_ids == []

@pytest.mark.asyncio
async def test_retrieve_relevant_docs_empty_query():
    service = RetrievalService()
    query = ""
    top_k = 3

    with patch.object(service.embedding_generator, 'generate_embedding', new_callable=AsyncMock) as mock_generate_embedding, \
         patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
        
        mock_generate_embedding.return_value = [0.1] * 384

        # Mock DB returning no documents
        mock_scalars_selected = AsyncMock()
        mock_scalars_selected.all.return_value = []
        mock_execute.return_value = mock_scalars_selected

        document_ids = await service.retrieve_relevant_docs(query, top_k)
        assert document_ids == []

@pytest.mark.asyncio
async def test_get_document_texts_valid_ids():
    service = RetrievalService()
    document_ids = [1, 2, 3]

    with patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
        # Mock query result
        mock_scalars = AsyncMock()
        mock_scalars.all.return_value = ["Document 1 text", "Document 2 text", "Document 3 text"]
        mock_execute.return_value = mock_scalars

        document_texts = await service.get_document_texts(document_ids)
        assert document_texts == ["Document 1 text", "Document 2 text", "Document 3 text"]

@pytest.mark.asyncio
async def test_get_document_texts_no_ids():
    service = RetrievalService()
    document_ids = []

    with patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
        document_texts = await service.get_document_texts(document_ids)
        assert document_texts == []

I have added a lot of debugging output, but I do not understand why the service does not return [1, 2, 3] when I mock execute with a side effect that should yield that value. The failures suggest that my mock_execute.side_effect is not taking effect at all.
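For reference, here is a minimal sketch of how `side_effect` behaves on an `AsyncMock`, using only `unittest.mock` from the standard library. The names `first_result` and `second_result` are placeholders invented for this illustration. It suggests that the side effect itself is applied: each awaited call hands back the next item of the list unchanged, and a `return_value` is ignored while `side_effect` is set.

```python
import asyncio
from unittest.mock import AsyncMock

# Placeholder objects standing in for two query results (hypothetical names).
first_result, second_result = object(), object()

mock_execute = AsyncMock()
mock_execute.side_effect = [first_result, second_result]
# Setting return_value afterwards does not override an active side_effect.
mock_execute.return_value = "ignored while side_effect is set"


async def call_twice():
    a = await mock_execute("query 1")
    b = await mock_execute("query 2")
    return a, b


a, b = asyncio.run(call_twice())
print(a is first_result, b is second_result, mock_execute.await_count)
```

Note that items in the list are returned as they are, not called. In the second test above each item is itself an `AsyncMock(return_value=...)`, so the service receives that outer mock rather than the object configured as its `return_value`.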

These are the logs that get printed:

selected_ids_result
debug results {'_mock_return_value': sentinel.DEFAULT, '_mock_parent': None, '_mock_name': None, '_mock_new_name': '()', '_mock_new_parent': , '_mock_sealed': False, '_spec_class': None, '_spec_set': None, '_spec_signature': None, '_mock_methods': None, '_spec_asyncs': [], '_mock_children': {'scalars': , 'str': }, '_mock_wraps': None, '_mock_delegate': None, '_mock_called': False, '_mock_call_args': None, '_mock_call_count': 0, '_mock_call_args_list': [], '_mock_mock_calls': [call.str(), call.scalars(), call.scalars().all()], 'method_calls': [call.scalars()], '_mock_unsafe': False, '_mock_side_effect': None, '_is_coroutine': <object object at 0x0000029C06E16A30>, '_mock_await_count': 0, '_mock_await_args': None, '_mock_await_args_list': [], 'code': , 'str': }
document_ids []

and these errors:

short test summary info ====================================
FAILED src/tests/unit/test_retrieve_docs.py::test_retrieve_relevant_docs_valid_query - AssertionError: Expected [1, 2, 3] but got []
FAILED src/tests/unit/test_retrieve_docs.py::test_retrieve_relevant_docs_valid_query_1 - AssertionError: assert [] == [1, 2, 3]
FAILED src/tests/unit/test_retrieve_docs.py::test_get_document_texts_valid_ids - AssertionError: assert <coroutine object AsyncMockMixin._execute_mock_call at 0x000002373EA3B...

Observed Behavior: results.scalars().all() unexpectedly returns [], even though I attempted to mock it. Debugging vars(results) shows _mock_side_effect = None, suggesting the mock isn't working as expected.
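The behavior can be reproduced with `unittest.mock` alone, which may help narrow it down. The sketch below is not the service code. It configures an `AsyncMock` the same way as the first test and then performs the two operations the service performs on the result: `list(await results.scalars())` and an un-awaited `.all()`. As far as I can tell, iterating a mock uses its default `__iter__`, which yields nothing, and calling any child of an `AsyncMock` produces a coroutine. Also note that `vars(results)` describes the result object, on which no side effect was ever set; the side effect was assigned to `mock_execute`.

```python
import asyncio
from unittest.mock import AsyncMock

# Configured like mock_scalars_selected in the first test.
result = AsyncMock()
result.scalars.return_value.all.return_value = [1, 2, 3]


async def demo():
    scalars = await result.scalars()   # the configured return_value, itself an AsyncMock
    iterated = list(scalars)           # iterates the mock; the configured .all() is never consulted
    pending = scalars.all()            # .all is an AsyncMock child, so this is a coroutine
    is_coroutine = asyncio.iscoroutine(pending)
    awaited = await pending            # only awaiting yields the configured list
    return iterated, is_coroutine, awaited


iterated, is_coroutine, awaited = asyncio.run(demo())
print(iterated, is_coroutine, awaited)
```

This would account for both symptoms in the logs: the empty list from `list(...)` in `retrieve_relevant_docs`, and the coroutine object returned by `get_document_texts`.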

Expected Behavior: document_ids should contain [1, 2, 3], matching the mocked return value.

What I've Tried: Explicitly setting scalars().all().return_value = [1, 2, 3]. Checking vars(results) for missing attributes. Ensuring mock_execute.side_effect is properly assigned. Calling await session.execute(...).scalars().all() instead of wrapping it in list().

What is the correct way to mock SQLAlchemy's async execution (session.execute().scalars().all()) in a FastAPI test using AsyncMock? Or can someone point out why my mock is not behaving as I expect it to?
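One pattern that seems to fit is sketched below, under a stated assumption: that with a real `AsyncSession` only `execute()` is awaited, while `scalars()` and `all()` on the returned result are ordinary synchronous calls. If that holds, the result is best modelled with `MagicMock` and only `execute` with `AsyncMock`. `fetch_ids` and `make_result` are hypothetical helpers written for this illustration, and the plain string query stands in for a `text(...)` clause; none of this is the real service.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def fetch_ids(session):
    # Hypothetical stand-in for one query in the service:
    # execute() is awaited, scalars() and all() are called synchronously.
    result = await session.execute("SELECT document_id FROM selected_documents")
    return result.scalars().all()


def make_result(rows):
    # A synchronous fake result: result.scalars().all() returns rows.
    result = MagicMock()
    result.scalars.return_value.all.return_value = rows
    return result


session = MagicMock()
# One fake result per expected execute() call, in call order.
session.execute = AsyncMock(side_effect=[make_result([1, 2, 3]), make_result([7, 8])])

first = asyncio.run(fetch_ids(session))
second = asyncio.run(fetch_ids(session))
print(first, second)
```

The same `side_effect` list could presumably be assigned to the `mock_execute` created by `patch.object(AsyncSession, 'execute', new_callable=AsyncMock)`. For this to work, the service would have to call `results.scalars().all()` without awaiting `scalars()`, which is what the assumption above implies the real API expects.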

I feel that if I fix one test, all my tests should be fixable the same way. I am not new to Python, but I am very new to SQLAlchemy.
