Loading [MathJax]/jax/output/HTML-CSS/config.js

2025/02/20

テクノロジー

APIのモックを用いたユニットテストとE2Eテストについて

この記事の目次

今回はAPIのモックを用いたユニットテストとE2Eテストについて実際のコードを使いながら紹介しようと思います。

モックを用いたユニットテストの概要

モックとは

まずモックとは何かについて説明します。

テストしたい関数が他のクラスに依存していることはよくあると思います。
例えば、SNSのとあるユーザーの投稿を取得するAPIのサービスクラスは投稿が公開か未公開か確認のために投稿のレポジトリクラスに依存し、またユーザーの存在を確認するためにユーザーのレポジトリクラスにも依存しています。
この状況下において、モックを使わずにUTを実装すると、ユーザーのレポジトリクラスのUTが失敗した場合、サービスクラスのUTも失敗しているということになります。そのため、原因特定に時間がかかります。

モックを使って実装すると、依存しているクラスや関数が想定通りの挙動をするように設定できるため、ユーザーのレポジトリクラスのUTが失敗した場合でも、サービスクラスのUTは成功します。そのため、瞬時にユーザーのレポジトリクラスのみでバグが生じていることがわかります。
つまり、モックとはUTの責任範囲を明確にし、 UTを実装しやすくする存在です。

依存注入

ただ、注意しなくてはいけないのは UTの対象関数の内、UTが制御できるのは対象関数の呼び方のみであるということです。
つまり、モックを使ってUTを制御するためには、モックするクラスを関数またはクラスの引数に設定する必要があります。

そのため、対象関数またはクラスの引数はクラスを注入できるように実装する必要があります。
これを依存注入と呼びます。

モックを用いたUTの実例

モックや依存注入について説明が終わったため、実際のコードを使って説明したいと思います。
今回はSNSのとあるユーザーの投稿を取得するAPIとそのUTコードを実装しました。

コントローラー

コントローラーのコードは以下のようになっています。
get_post_info関数の引数にサービスクラスを依存注入しています。

/controllers/get_post_controller.py<code>from dependency_injector.wiring import Provide, inject
from fastapi import APIRouter, Depends

from app.api_schemas.get_post_schema import (GetPostRequest, GetPostResponse,
                                             GetPostSchema)
from app.core.container import Container
from app.services.get_post_service import GetPostService

router = APIRouter()


@router.get("/posts/{post_id}")
@inject
def get_post_info(
    get_post_request: GetPostRequest = Depends(),
    service: GetPostService = Depends(Provide[Container.get_post_service]),
) -> GetPostResponse:
    if post := service.get_post_info(
        get_post_request.post_id, get_post_request.user_id
    ):
        return GetPostResponse(
            result=True,
            post=GetPostSchema(title=post.title, description=post.description),
        )
    return GetPostResponse(result=False, post=None)
</code>

これに対するUTコードは以下のようになっています。
mock_get_post_service関数でサービスのモックを作成し、各テストケースで利用しています。

コントローラーのget_post_info関数の中で使うサービスのメソッドの返り値をモックで設定することで関数内の条件分岐を制御しています。

/tests/controllers/test_get_post_controller.pyfrom unittest.mock import MagicMock

import pytest
import requests

from app.api_schemas.get_post_schema import (GetPostRequest, GetPostResponse,
                                             GetPostSchema)
from app.controllers.get_post_controller import get_post_info
from app.models.post import PostTable
from app.models.user import UserTable
from tests.base_test import BaseTest

@pytest.fixture()
def mock_get_post_service():
    return MagicMock()

def test_get_post_succeeds(mock_get_post_service):
    mock_get_post_service.get_post_info.return_value = (
        PostTable.test_public_post_by_user1_data()
    )

    request = GetPostRequest(post_id=1, user_id=1)
    response = get_post_info(get_post_request=request, service=mock_get_post_service)
    assert response == GetPostResponse(
        result=True,
        post=GetPostSchema(
            title=PostTable.test_public_post_by_user1_data().title,
            description=PostTable.test_public_post_by_user1_data().description,
        ),
    )

def test_get_post_fails(mock_get_post_service):
    mock_get_post_service.get_post_info.return_value = None

    request = GetPostRequest(post_id=1, user_id=1)
    response = get_post_info(get_post_request=request, service=mock_get_post_service)
    assert response == GetPostResponse(result=False, post=None)

サービス

サービスのコードは以下のようになっています。

サービスはクラスにまとめているため、クラスの__init__関数で依存するユーザーレポジトリクラスと投稿レポジトリクラスを注入しています。
これによってクラス内の関数のインスタンスから依存先を利用できるようになっています。

/services/get_post_service.py<code>from app.models.post import PostTable
from app.repositories.post_repository import PostRepository
from app.repositories.user_repository import UserRepository


class GetPostService:
    def __init__(
        self, post_repository: PostRepository, user_repository: UserRepository
    ):
        self.post_repository = post_repository
        self.user_repository = user_repository

    def get_post_info(self, post_id, user_id) -> PostTable:
        if not self.user_repository.get_user(user_id):
            return None
        post = self.post_repository.get_post(post_id)
        if not post or not self.__is_visible(post, user_id):
            return None
        return post

    def __is_visible(self, post: PostTable, user_id) -> bool:
        if post.user_id == user_id:
            return True
        elif not post.is_private:
            return True
        else:
            return False
</code>

これに対するUTコードは以下のようになっています。

get_post_service関数で依存するユーザーレポジトリクラスと投稿レポジトリクラスをモックしています。
サービスのget_post_info関数の中で使うレポジトリのメソッドの返り値をモックで設定することで関数内の条件分岐を制御しています。

tests/services/test_get_post_service.py<code>from datetime import datetime
from unittest.mock import MagicMock

import pytest

from app.models.post import PostTable
from app.models.user import UserTable
from app.services.get_post_service import GetPostService


@pytest.fixture()
def get_post_service():
    return GetPostService(post_repository=MagicMock(), user_repository=MagicMock())


def test_non_existing_user(get_post_service):
    non_existing_user_id = 1
    get_post_service.user_repository.get_user.return_value = None
    assert get_post_service.get_post_info(1, non_existing_user_id) == None


def test_non_existing_post(get_post_service):
    non_existing_post_id = 1
    get_post_service.user_repository.get_user.return_value = (
        UserTable.test_not_login_user1_data()
    )
    get_post_service.post_repository.get_post.return_value = None
    assert (
        get_post_service.get_post_info(
            non_existing_post_id, UserTable.test_not_login_user1_data().id
        )
        == None
    )


def test_get_private_post_from_non_author(get_post_service):
    get_post_service.user_repository.get_user.return_value = (
        UserTable.test_not_login_user1_data()
    )
    get_post_service.post_repository.get_post.return_value = (
        PostTable.test_private_post_by_user1_data()
    )
    assert (
        get_post_service.get_post_info(
            UserTable.test_not_login_user1_data().id,
            PostTable.test_private_post_by_user1_data().id,
        )
        == None
    )


def test_get_private_post_from_author(get_post_service):
    get_post_service.user_repository.get_user.return_value = (
        UserTable.test_not_login_user1_data()
    )
    get_post_service.post_repository.get_post.return_value = (
        PostTable.test_private_post_by_user1_data()
    )
    assert (
        get_post_service.get_post_info(
            PostTable.test_private_post_by_user1_data().id,
            UserTable.test_not_login_user1_data().id,
        )
        == get_post_service.post_repository.get_post.return_value
    )


def test_get_public_post_from_non_author(get_post_service):
    get_post_service.user_repository.get_user.return_value = (
        UserTable.test_login_user2_data()
    )
    get_post_service.post_repository.get_post.return_value = (
        PostTable.test_public_post_by_user1_data()
    )
    assert (
        get_post_service.get_post_info(
            PostTable.test_public_post_by_user1_data().id,
            UserTable.test_login_user2_data().id,
        )
        == get_post_service.post_repository.get_post.return_value
    )
</code>

レポジトリ

レポジトリのコードは以下のようになっています。

レポジトリはクラスにまとめているため、クラスの__init__関数で依存するDBを注入しています。
DBを注入することで開発環境のDBとは別のDBにデータを入れることができるため、開発環境のDBに影響を与えずに済みます。
これによってクラス内の関数のインスタンスから依存先を利用できるようになっています。

/repositories/post_repository.py<code>from pydantic import BaseModel
from sqlalchemy.orm import Session

from app.models.post import PostTable
from app.models.user import UserTable


class PostRepository:
    def __init__(self, db: Session):
        self.db = db

    def get_post(self, post_id) -> PostTable:
        return (
            self.db.query(PostTable)
            .join(UserTable, UserTable.id == PostTable.user_id)
            .filter(PostTable.id == post_id)
            .first()
        )
</code>

これに対するUTコードは以下のようになっています。

レポジトリはAPIの最奥層であるため、何もモックせずに実際にテスト用DBにデータを入れた上でUTを書いています。
テスト用DBとアプリ用DBの切り替えはbase_test.pyで行っていますが、ここでは省略します。

/tests/repositories/test_post_repository.py<code>from app.helpers.helper import get_datetime_now_db_format
from app.models.post import PostTable
from app.models.user import UserTable
from app.repositories.post_repository import PostRepository
from tests.base_test import BaseTest


class TestPostRepository(BaseTest):
    @classmethod
    def _initialize_repository(cls):
        cls.post_repository = PostRepository(cls.db)

    @classmethod
    def _insert_data(cls):
        cls.db.add_all(
            [
                PostTable.test_public_post_by_user1_data(),
                UserTable.test_not_login_user1_data(),
            ]
        )
        cls.db.commit()

    @classmethod
    def test_get_existing_post(cls):
        response = cls.post_repository.get_post(
            PostTable.test_public_post_by_user1_data().id
        )
        assert response.id == PostTable.test_public_post_by_user1_data().id
        assert response.title == PostTable.test_public_post_by_user1_data().title
        assert (
            response.description
            == PostTable.test_public_post_by_user1_data().description
        )
        assert response.user_id == PostTable.test_public_post_by_user1_data().user_id
        assert (
            response.is_private == PostTable.test_public_post_by_user1_data().is_private
        )

    @classmethod
    def test_get_non_existing_post(cls):
        non_existing_post_id = 2
        response = cls.post_repository.get_post(non_existing_post_id)
        assert response == None
</code>

E2Eテスト

E2Eテストとはシステム全体をテストするものです。
E2Eテストを実行することで関数間の値の受け渡しが正常であることを担保し、UTのみではカバーできないところをカバーし、バグが発生する可能性を下げることができます。

このレポジトリはAPIしか作成していないため、フロントエンドの挙動までは確認しません。
ここでは特定のパスにリクエストが来てからレスポンスが返されるまでの一連の動作を確認します。
そのため、ここでも実際にデータを入れます。

/tests/controllers/test_get_post_controller.py<code>from unittest.mock import MagicMock

import pytest
import requests

from app.api_schemas.get_post_schema import (GetPostRequest, GetPostResponse,
                                             GetPostSchema)
from app.controllers.get_post_controller import get_post_info
from app.models.post import PostTable
from app.models.user import UserTable
from tests.base_test import BaseTest

class TestGetPostController(BaseTest):
    @classmethod
    def _insert_data(cls):
        cls.db.add_all(
            [
                UserTable.test_not_login_user1_data(),
                PostTable.test_public_post_by_user1_data(),
            ]
        )
        cls.db.commit()

    @classmethod
    def test_e2e(cls):
        post_id = str(PostTable.test_public_post_by_user1_data().id)
        user_id = str(UserTable.test_not_login_user1_data().id)
        response = cls.client.get(
            "/posts/" + post_id, params={"post_id": post_id, "user_id": user_id}
        )
        assert response.status_code == 200
        assert response.json() == {
            "result": True,
            "post": {
                "title": PostTable.test_public_post_by_user1_data().title,
                "description": PostTable.test_public_post_by_user1_data().description,
            },
        }
</code>

最後に

APIのモックを用いたユニットテストは、自分が書いたコードが仕様を正しく反映していることを迅速に確認するための非常に効率的な手法です。
一定期間が経過して仕様を忘れてしまった場合でも、仕様がコードとして明確に表現されているため、再確認が容易になります。

また、E2Eテストはユニットテストだけでは見落としがちな、全体の動作を確認するのに非常に有効です。
これにより、実際の動作環境での問題を早期に発見し、修正することができます。

ぜひ、これらのテスト手法を取り入れて、効率的にAPI開発を進めていただければと思います。
今後も、テストの重要性を意識しながら、より良いソフトウェアを作り上げていきましょう。

※本記事は2025年02月時点の情報です。

著者:マイナビエンジニアブログ編集部