SQLAlchemyでの多対多(many to many)の保存方法

Web開発

みなさん、こんにちは!

本日は、SQLAlchemyでの多対多(many to many)の保存方法について記載します。

私の場合、今回試すものと条件は異なるのですが、少しはまりました。。ので。参考になれば幸いです。

SQLAlchemyのモデル定義についてはこちらで書いてます。

今回も☝のリンク先の記事と同じように、こちらのfastapiのテンプレートを使用してます。

GitHub - tiangolo/full-stack-fastapi-postgresql: Full stack, modern web application generator. Using FastAPI, PostgreSQL as database, Docker, automatic HTTPS and more.
Full stack, modern web application generator. Using FastAPI, PostgreSQL as database, Docker, automatic HTTPS and more. - GitHub - tiangolo/full-stack-fastapi-po...

fastapiを使わない人でも、これからSQLAlchemyを使用してコードを書く人には参考になるかもしれません。

やりたいこと

SQLAlchemyで定義したモデルにmany to manyの関係性を持つモデルが存在するとする。

many to manyの関係の場合に、具体的にどのように保存するのかのロジックが知りたい。

そこで、今回はUserモデルが複数のCompanyモデルに所属できることを考えてみたい。

many to many関係のモデルを作成する

今回テストで使用するモデルは以下です。

Companyモデルは、複数のUserを持てるため、relationshipを使用してusersを持っている。

SQLAlchemyの場合は、secondaryを使用してmany to manyの関係を示すマップテーブル(中間テーブル)を定義する。

# backend/app/app/models/company.py
from typing import TYPE_CHECKING

from sqlalchemy import (
    Boolean, Column, Integer, String,
    DateTime, func
    )
from sqlalchemy.orm import relationship

from app.db.base_class import Base

from app.models import company_user_map_table

if TYPE_CHECKING:
    from .user import User  # noqa: F401


class Company(Base):
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100), index=True)
    is_active = Column(Boolean(), default=True)
    created_at = Column(
        DateTime,
        server_default=func.now()
        )
    updated_at = Column(
        DateTime,
        server_default=func.now(),
        onupdate=func.now()
        )

    users = relationship(
        "User",
        secondary=company_user_map_table,
        back_populates="companies"
        )

Userモデルも同じく、複数のCompanyに所属できるため、relationshipを使用してcompaniesを持っている。

# backend/app/app/models/user.py
from typing import TYPE_CHECKING

from sqlalchemy import (
    Boolean, Column, Integer, String,
    DateTime, func
    )
from sqlalchemy.orm import relationship

from app.db.base_class import Base

from app.models import company_user_map_table

if TYPE_CHECKING:
    from .item import Item  # noqa: F401
    from .company import Company  # noqa: F401


class User(Base):
    id = Column(Integer, primary_key=True, index=True)
    full_name = Column(String(100), index=True)
    email = Column(String(100), unique=True, index=True, nullable=False)
    hashed_password = Column(String(100), nullable=False)
    is_active = Column(Boolean(), default=True)
    is_superuser = Column(Boolean(), default=False)
    created_at = Column(
        DateTime,
        server_default=func.now()
        )
    updated_at = Column(
        DateTime,
        server_default=func.now(),
        onupdate=func.now()
        )

    companies = relationship(
        "Company",
        secondary=company_user_map_table,
        back_populates="users"
        )
    items = relationship("Item", back_populates="owner")

最後のモデルは、UserモデルとCompanyモデルをつなぐためのマップテーブル(中間テーブル)。

SQLAlchemyでは以下のように書けます。

# backend/app/app/models/company_user_map.py
from sqlalchemy import (Table, ForeignKey, Column, Integer)

from app.db.base_class import Base


company_user_map_table = Table(
    'company_user_map',
    Base.metadata,
    Column('company_id', Integer, ForeignKey('company.id')),
    Column('user_id', Integer, ForeignKey('user.id'))
    )

Schemaを作成する

今回あまりCompanyのschemaは重要ではないのですが、次項のCRUDで一部使用しているので準備しました。

# backend/app/app/schemas/company.py
from typing import Optional

from pydantic import BaseModel


# Shared properties
class CompanyBase(BaseModel):
    name: Optional[str] = None
    is_active: Optional[bool] = True


# Properties to receive via API on creation
class CompanyCreate(CompanyBase):
    name: str


# Properties to receive via API on update
class CompanyUpdate(CompanyBase):
    pass


class CompanyInDBBase(CompanyBase):
    id: Optional[int] = None

    class Config:
        orm_mode = True


# Additional properties to return via API
class Company(CompanyInDBBase):
    pass


# Additional properties stored in DB
class CompanyInDB(CompanyInDBBase):
    pass

SchemaのUserクラスには、単にAPIで値を取得する際に、companiesが含まれるように追加してます。

# backend/app/app/schemas/user.py
from typing import Optional, List

from pydantic import BaseModel, EmailStr

from app.schemas.company import Company


# Shared properties
class UserBase(BaseModel):
    email: Optional[EmailStr] = None
    is_active: Optional[bool] = True
    is_superuser: bool = False
    full_name: Optional[str] = None


# Properties to receive via API on creation
class UserCreate(UserBase):
    email: EmailStr
    password: str


# Properties to receive via API on update
class UserUpdate(UserBase):
    password: Optional[str] = None


class UserInDBBase(UserBase):
    id: Optional[int] = None

    class Config:
        orm_mode = True


# Additional properties to return via API
class User(UserInDBBase):
    companies: List[Company]


# Additional properties stored in DB
class UserInDB(UserInDBBase):
    hashed_password: str

CRUDを作成する

今回は、APIでCompanyのIDを複数受け取って、そのIDのCompanyとUserを紐づけたいので、`get_by_ids`関数を用意しました。

# backend/app/app/crud/crud_company.py
from typing import List

from sqlalchemy.orm import Session

from app.crud.base import CRUDBase
from app.models.company import Company
from app.schemas.company import CompanyCreate, CompanyUpdate


class CRUDCompany(CRUDBase[Company, CompanyCreate, CompanyUpdate]):
    def get_by_ids(
        self,
        db: Session,
        *,
        company_ids: List[int]
    ) -> List[Company]:
        return db.query(Company).filter(Company.id.in_(company_ids)).all()


company = CRUDCompany(Company)

UserのCRUDのほうは、passwordなしでPOSTした際にエラーが出たので、以下修正しました。

if update_data["password"]:

if "password" in update_data.keys() and update_data["password"]:

他はとくに触ってないです。

# backend/app/app/crud/crud_user.py
from typing import Any, Dict, Optional, Union

from sqlalchemy.orm import Session

from app.core.security import get_password_hash, verify_password
from app.crud.base import CRUDBase
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate


class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
    .... 省略 ....

    def update(
        self,
        db: Session,
        *,
        db_obj: User,
        obj_in: Union[UserUpdate, Dict[str, Any]]
    ) -> User:
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.dict(exclude_unset=True)
        if "password" in update_data.keys() and update_data["password"]:
            hashed_password = get_password_hash(update_data["password"])
            del update_data["password"]
            update_data["hashed_password"] = hashed_password
        return super().update(db, db_obj=db_obj, obj_in=update_data)

    .... 省略 ....


user = CRUDUser(User)

多対多(many to many)の更新ロジック部分

更新ロジックは、☝でいろいろコードを追加したので、とてもシンプルに書くことができます。

以下を追加しただけです。

if company_ids is not None:
    companies = crud.company.get_by_ids(db, company_ids=company_ids)
    current_user.companies = companies

これで、マップテーブル(中間テーブル)にもデータを自動で書き込んでくれます。

※company, userデータ削除時にも、マップテーブルのデータは自動で削除されます。

# backend/app/app/api/api_v1/endpoints/users.py
@router.put("/me", response_model=schemas.User)
def update_user_me(
    *,
    db: Session = Depends(deps.get_db),
    password: str = Body(None),
    full_name: str = Body(None),
    email: EmailStr = Body(None),
    company_ids: List[int] = Body(None),
    current_user: models.User = Depends(deps.get_current_active_user),
) -> Any:
    """
    Update own user.
    """
    current_user_data = jsonable_encoder(current_user)
    user_in = schemas.UserUpdate(**current_user_data)
    if password is not None:
        user_in.password = password
    if full_name is not None:
        user_in.full_name = full_name
    if email is not None:
        user_in.email = email
    if company_ids is not None:
        companies = crud.company.get_by_ids(db, company_ids=company_ids)
        current_user.companies = companies
    user = crud.user.update(db, db_obj=current_user, obj_in=user_in)
    return user

以上で、User又はCompany取得時にrelationshipで定義した項目を参照することができます。

# GET: http://192.168.XX.XX/api/v1/users/me
{
    "email": "test@example.com",
    "is_active": true,
    "is_superuser": false,
    "full_name": "tanaka taro",
    "id": 1,
    "companies": [
        {
            "name": "test company",
            "is_active": true,
            "id": 1
        }
    ]
}

参考

Flask sqlalchemy many-to-many insert data - Stack Overflow
タイトルとURLをコピーしました