SQLAlchemyでの多対多(many to many)の保存方法

Web開発

みなさん、こんにちは!

本日は、SQLAlchemyでの多対多(many to many)の保存方法について記載します。

私の場合、今回試すものと条件は異なるのですが、少しはまりました。。ので。参考になれば幸いです。

SQLAlchemyのモデル定義についてはこちらで書いてます。

今回も☝のリンク先の記事と同じように、こちらのfastapiのテンプレートを使用してます。

tiangolo/full-stack-fastapi-postgresql
Full stack, modern web application generator. Using FastAPI, PostgreSQL as database, Docker, automatic HTTPS and more. - tiangolo/full-stack-fastapi-postgresql

fastapiを使わない人でも、これからSQLAlchemyを使用してコードを書く人には参考になるかもしれません。

やりたいこと

SQLAlchemyで定義したモデルにmany to manyの関係性を持つモデルが存在するとする。

many to manyの関係の場合に、具体的にどのように保存するのかのロジックが知りたい。

そこで、今回はUserモデルが複数のCompanyモデルに所属できることを考えてみたい。

many to many関係のモデルを作成する

今回テストで使用するモデルは以下です。

Companyモデルは、複数のUserを持てるため、relationshipを使用してusersを持っている。

SQLAlchemyの場合は、secondaryを使用してmany to manyの関係を示すマップテーブル(中間テーブル)を定義する。

# backend/app/app/models/company.py
from typing import TYPE_CHECKING

from sqlalchemy import (
    Boolean, Column, Integer, String,
    DateTime, func
    )
from sqlalchemy.orm import relationship

from app.db.base_class import Base

from app.models import company_user_map_table

if TYPE_CHECKING:
    from .user import User  # noqa: F401


class Company(Base):
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100), index=True)
    is_active = Column(Boolean(), default=True)
    created_at = Column(
        DateTime,
        server_default=func.now()
        )
    updated_at = Column(
        DateTime,
        server_default=func.now(),
        onupdate=func.now()
        )

    users = relationship(
        "User",
        secondary=company_user_map_table,
        back_populates="companies"
        )

Userモデルも同じく、複数のCompanyに所属できるため、relationshipを使用してcompaniesを持っている。

# backend/app/app/models/user.py
from typing import TYPE_CHECKING

from sqlalchemy import (
    Boolean, Column, Integer, String,
    DateTime, func
    )
from sqlalchemy.orm import relationship

from app.db.base_class import Base

from app.models import company_user_map_table

if TYPE_CHECKING:
    from .item import Item  # noqa: F401
    from .company import Company  # noqa: F401


class User(Base):
    id = Column(Integer, primary_key=True, index=True)
    full_name = Column(String(100), index=True)
    email = Column(String(100), unique=True, index=True, nullable=False)
    hashed_password = Column(String(100), nullable=False)
    is_active = Column(Boolean(), default=True)
    is_superuser = Column(Boolean(), default=False)
    created_at = Column(
        DateTime,
        server_default=func.now()
        )
    updated_at = Column(
        DateTime,
        server_default=func.now(),
        onupdate=func.now()
        )

    companies = relationship(
        "Company",
        secondary=company_user_map_table,
        back_populates="users"
        )
    items = relationship("Item", back_populates="owner")

最後のモデルは、UserモデルとCompanyモデルをつなぐためのマップテーブル(中間テーブル)。

SQLAlchemyでは以下のように書けます。

# backend/app/app/models/company_user_map.py
from sqlalchemy import (Table, ForeignKey, Column, Integer)

from app.db.base_class import Base


company_user_map_table = Table(
    'company_user_map',
    Base.metadata,
    Column('company_id', Integer, ForeignKey('company.id')),
    Column('user_id', Integer, ForeignKey('user.id'))
    )

Schemaを作成する

今回あまりCompanyのschemaは重要ではないのですが、次項のCRUDで一部使用しているので準備しました。

# backend/app/app/schemas/company.py
from typing import Optional

from pydantic import BaseModel


# Shared properties
class CompanyBase(BaseModel):
    name: Optional[str] = None
    is_active: Optional[bool] = True


# Properties to receive via API on creation
class CompanyCreate(CompanyBase):
    name: str


# Properties to receive via API on update
class CompanyUpdate(CompanyBase):
    pass


class CompanyInDBBase(CompanyBase):
    id: Optional[int] = None

    class Config:
        orm_mode = True


# Additional properties to return via API
class Company(CompanyInDBBase):
    pass


# Additional properties stored in DB
class CompanyInDB(CompanyInDBBase):
    pass

SchemaのUserクラスには、単にAPIで値を取得する際に、companiesが含まれるように追加してます。

# backend/app/app/schemas/user.py
from typing import Optional, List

from pydantic import BaseModel, EmailStr

from app.schemas.company import Company


# Shared properties
class UserBase(BaseModel):
    email: Optional[EmailStr] = None
    is_active: Optional[bool] = True
    is_superuser: bool = False
    full_name: Optional[str] = None


# Properties to receive via API on creation
class UserCreate(UserBase):
    email: EmailStr
    password: str


# Properties to receive via API on update
class UserUpdate(UserBase):
    password: Optional[str] = None


class UserInDBBase(UserBase):
    id: Optional[int] = None

    class Config:
        orm_mode = True


# Additional properties to return via API
class User(UserInDBBase):
    companies: List[Company]


# Additional properties stored in DB
class UserInDB(UserInDBBase):
    hashed_password: str

CRUDを作成する

今回は、APIでCompanyのIDを複数受け取って、そのIDのCompanyとUserを紐づけたいので、`get_by_ids`関数を用意しました。

# backend/app/app/crud/crud_company.py
from typing import List

from sqlalchemy.orm import Session

from app.crud.base import CRUDBase
from app.models.company import Company
from app.schemas.company import CompanyCreate, CompanyUpdate


class CRUDCompany(CRUDBase[Company, CompanyCreate, CompanyUpdate]):
    def get_by_ids(
        self,
        db: Session,
        *,
        company_ids: List[int]
    ) -> List[Company]:
        return db.query(Company).filter(Company.id.in_(company_ids)).all()


company = CRUDCompany(Company)

UserのCRUDのほうは、passwordなしでPOSTした際にエラーが出たので、以下修正しました。

if update_data["password"]:

if "password" in update_data.keys() and update_data["password"]:

他はとくに触ってないです。

# backend/app/app/crud/crud_user.py
from typing import Any, Dict, Optional, Union

from sqlalchemy.orm import Session

from app.core.security import get_password_hash, verify_password
from app.crud.base import CRUDBase
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate


class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
    .... 省略 ....

    def update(
        self,
        db: Session,
        *,
        db_obj: User,
        obj_in: Union[UserUpdate, Dict[str, Any]]
    ) -> User:
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.dict(exclude_unset=True)
        if "password" in update_data.keys() and update_data["password"]:
            hashed_password = get_password_hash(update_data["password"])
            del update_data["password"]
            update_data["hashed_password"] = hashed_password
        return super().update(db, db_obj=db_obj, obj_in=update_data)

    .... 省略 ....


user = CRUDUser(User)

多対多(many to many)の更新ロジック部分

更新ロジックは、☝でいろいろコードを追加したので、とてもシンプルに書くことができます。

以下を追加しただけです。

if company_ids is not None:
    companies = crud.company.get_by_ids(db, company_ids=company_ids)
    current_user.companies = companies

これで、マップテーブル(中間テーブル)にもデータを自動で書き込んでくれます。

※company, userデータ削除時にも、マップテーブルのデータは自動で削除されます。

# backend/app/app/api/api_v1/endpoints/users.py
@router.put("/me", response_model=schemas.User)
def update_user_me(
    *,
    db: Session = Depends(deps.get_db),
    password: str = Body(None),
    full_name: str = Body(None),
    email: EmailStr = Body(None),
    company_ids: List[int] = Body(None),
    current_user: models.User = Depends(deps.get_current_active_user),
) -> Any:
    """
    Update own user.
    """
    current_user_data = jsonable_encoder(current_user)
    user_in = schemas.UserUpdate(**current_user_data)
    if password is not None:
        user_in.password = password
    if full_name is not None:
        user_in.full_name = full_name
    if email is not None:
        user_in.email = email
    if company_ids is not None:
        companies = crud.company.get_by_ids(db, company_ids=company_ids)
        current_user.companies = companies
    user = crud.user.update(db, db_obj=current_user, obj_in=user_in)
    return user

以上で、User又はCompany取得時にrelationshipで定義した項目を参照することができます。

# GET: http://192.168.XX.XX/api/v1/users/me
{
    "email": "test@example.com",
    "is_active": true,
    "is_superuser": false,
    "full_name": "tanaka taro",
    "id": 1,
    "companies": [
        {
            "name": "test company",
            "is_active": true,
            "id": 1
        }
    ]
}

参考

Flask sqlalchemy many-to-many insert data
I am trying to make a many to many relation here in Flask-SQLAlchemy, but it seems that I don't know how to fill the "many to many identifier database". Could y...
タイトルとURLをコピーしました