みなさん、こんにちは!
本日は、SQLAlchemyでの多対多(many to many)の保存方法について記載します。
私の場合、今回試すものと条件は異なるのですが、少しはまりました。。ので。参考になれば幸いです。
SQLAlchemyのモデル定義についてはこちらで書いてます。
今回も☝のリンク先の記事と同じように、こちらのfastapiのテンプレートを使用してます。
fastapiを使わない人でも、これからSQLAlchemyを使用してコードを書く人には参考になるかもしれません。
やりたいこと
SQLAlchemyで定義したモデルにmany to manyの関係性を持つモデルが存在するとする。
many to manyの関係の場合に、具体的にどのように保存するのかのロジックが知りたい。
そこで、今回はUserモデルが複数のCompanyモデルに所属できることを考えてみたい。
many to many関係のモデルを作成する
今回テストで使用するモデルは以下です。
Companyモデルは、複数のUserを持てるため、relationshipを使用してusersを持っている。
SQLAlchemyの場合は、secondaryを使用してmany to manyの関係を示すマップテーブル(中間テーブル)を定義する。
# backend/app/app/models/company.py
from typing import TYPE_CHECKING
from sqlalchemy import (
Boolean, Column, Integer, String,
DateTime, func
)
from sqlalchemy.orm import relationship
from app.db.base_class import Base
from app.models import company_user_map_table
if TYPE_CHECKING:
from .user import User # noqa: F401
class Company(Base):
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), index=True)
is_active = Column(Boolean(), default=True)
created_at = Column(
DateTime,
server_default=func.now()
)
updated_at = Column(
DateTime,
server_default=func.now(),
onupdate=func.now()
)
users = relationship(
"User",
secondary=company_user_map_table,
back_populates="companies"
)
Userモデルも同じく、複数のCompanyに所属できるため、relationshipを使用してcompaniesを持っている。
# backend/app/app/models/user.py
from typing import TYPE_CHECKING
from sqlalchemy import (
Boolean, Column, Integer, String,
DateTime, func
)
from sqlalchemy.orm import relationship
from app.db.base_class import Base
from app.models import company_user_map_table
if TYPE_CHECKING:
from .item import Item # noqa: F401
from .company import Company # noqa: F401
class User(Base):
id = Column(Integer, primary_key=True, index=True)
full_name = Column(String(100), index=True)
email = Column(String(100), unique=True, index=True, nullable=False)
hashed_password = Column(String(100), nullable=False)
is_active = Column(Boolean(), default=True)
is_superuser = Column(Boolean(), default=False)
created_at = Column(
DateTime,
server_default=func.now()
)
updated_at = Column(
DateTime,
server_default=func.now(),
onupdate=func.now()
)
companies = relationship(
"Company",
secondary=company_user_map_table,
back_populates="users"
)
items = relationship("Item", back_populates="owner")
最後のモデルは、UserモデルとCompanyモデルをつなぐためのマップテーブル(中間テーブル)。
SQLAlchemyでは以下のように書けます。
# backend/app/app/models/company_user_map.py
from sqlalchemy import (Table, ForeignKey, Column, Integer)
from app.db.base_class import Base
company_user_map_table = Table(
'company_user_map',
Base.metadata,
Column('company_id', Integer, ForeignKey('company.id')),
Column('user_id', Integer, ForeignKey('user.id'))
)
Schemaを作成する
今回あまりCompanyのschemaは重要ではないのですが、次項のCRUDで一部使用しているので準備しました。
# backend/app/app/schemas/company.py
from typing import Optional
from pydantic import BaseModel
# Shared properties
class CompanyBase(BaseModel):
name: Optional[str] = None
is_active: Optional[bool] = True
# Properties to receive via API on creation
class CompanyCreate(CompanyBase):
name: str
# Properties to receive via API on update
class CompanyUpdate(CompanyBase):
pass
class CompanyInDBBase(CompanyBase):
id: Optional[int] = None
class Config:
orm_mode = True
# Additional properties to return via API
class Company(CompanyInDBBase):
pass
# Additional properties stored in DB
class CompanyInDB(CompanyInDBBase):
pass
SchemaのUserクラスには、単にAPIで値を取得する際に、companiesが含まれるように追加してます。
# backend/app/app/schemas/user.py
from typing import Optional, List
from pydantic import BaseModel, EmailStr
from app.schemas.company import Company
# Shared properties
class UserBase(BaseModel):
email: Optional[EmailStr] = None
is_active: Optional[bool] = True
is_superuser: bool = False
full_name: Optional[str] = None
# Properties to receive via API on creation
class UserCreate(UserBase):
email: EmailStr
password: str
# Properties to receive via API on update
class UserUpdate(UserBase):
password: Optional[str] = None
class UserInDBBase(UserBase):
id: Optional[int] = None
class Config:
orm_mode = True
# Additional properties to return via API
class User(UserInDBBase):
companies: List[Company]
# Additional properties stored in DB
class UserInDB(UserInDBBase):
hashed_password: str
CRUDを作成する
今回は、APIでCompanyのIDを複数受け取って、そのIDのCompanyとUserを紐づけたいので、`get_by_ids`関数を用意しました。
# backend/app/app/crud/crud_company.py
from typing import List
from sqlalchemy.orm import Session
from app.crud.base import CRUDBase
from app.models.company import Company
from app.schemas.company import CompanyCreate, CompanyUpdate
class CRUDCompany(CRUDBase[Company, CompanyCreate, CompanyUpdate]):
def get_by_ids(
self,
db: Session,
*,
company_ids: List[int]
) -> List[Company]:
return db.query(Company).filter(Company.id.in_(company_ids)).all()
company = CRUDCompany(Company)
UserのCRUDのほうは、passwordなしでPOSTした際にエラーが出たので、以下修正しました。
if update_data["password"]:
↓
if "password" in update_data.keys() and update_data["password"]:
他はとくに触ってないです。
# backend/app/app/crud/crud_user.py
from typing import Any, Dict, Optional, Union
from sqlalchemy.orm import Session
from app.core.security import get_password_hash, verify_password
from app.crud.base import CRUDBase
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
.... 省略 ....
def update(
self,
db: Session,
*,
db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]]
) -> User:
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.dict(exclude_unset=True)
if "password" in update_data.keys() and update_data["password"]:
hashed_password = get_password_hash(update_data["password"])
del update_data["password"]
update_data["hashed_password"] = hashed_password
return super().update(db, db_obj=db_obj, obj_in=update_data)
.... 省略 ....
user = CRUDUser(User)
多対多(many to many)の更新ロジック部分
更新ロジックは、☝でいろいろコードを追加したので、とてもシンプルに書くことができます。
以下を追加しただけです。
if company_ids is not None:
companies = crud.company.get_by_ids(db, company_ids=company_ids)
current_user.companies = companies
これで、マップテーブル(中間テーブル)にもデータを自動で書き込んでくれます。
※company, userデータ削除時にも、マップテーブルのデータは自動で削除されます。
# backend/app/app/api/api_v1/endpoints/users.py
@router.put("/me", response_model=schemas.User)
def update_user_me(
*,
db: Session = Depends(deps.get_db),
password: str = Body(None),
full_name: str = Body(None),
email: EmailStr = Body(None),
company_ids: List[int] = Body(None),
current_user: models.User = Depends(deps.get_current_active_user),
) -> Any:
"""
Update own user.
"""
current_user_data = jsonable_encoder(current_user)
user_in = schemas.UserUpdate(**current_user_data)
if password is not None:
user_in.password = password
if full_name is not None:
user_in.full_name = full_name
if email is not None:
user_in.email = email
if company_ids is not None:
companies = crud.company.get_by_ids(db, company_ids=company_ids)
current_user.companies = companies
user = crud.user.update(db, db_obj=current_user, obj_in=user_in)
return user
以上で、User又はCompany取得時にrelationshipで定義した項目を参照することができます。
# GET: http://192.168.XX.XX/api/v1/users/me
{
"email": "test@example.com",
"is_active": true,
"is_superuser": false,
"full_name": "tanaka taro",
"id": 1,
"companies": [
{
"name": "test company",
"is_active": true,
"id": 1
}
]
}