侧边栏壁纸
博主头像
ZHD的小窝博主等级

行动起来,活在当下

  • 累计撰写 79 篇文章
  • 累计创建 53 个标签
  • 累计收到 1 条评论

目 录CONTENT

文章目录

FastAPI+sqlalchemy 全局SQL处理

江南的风
2025-03-10 / 0 评论 / 0 点赞 / 14 阅读 / 6453 字 / 正在检测是否收录...

以SaaS系统为例,租户直接数据应该是完全隔离的,当然隔离方式有两种:物理隔离和字段隔离

本文以字段隔离为例,通过租户字段完成租户数据的完全隔离,通过自定义async_session_factory的class来实现全局数据库操作的租户隔离

1. 定义base model

创建一个基本模型,这个基本模型是所有表中包含的字段,比如租户ID,创建时间,修改时间,修改人,备用字段等

from sqlalchemy import Column, String, event
from sqlalchemy.orm import declarative_base, Mapper, declared_attr
from app.database.tenant import TenantManager

Base = declarative_base()

class TenantAwareMixin(Base):
    __abstract__ = True
    """租户感知混入类"""
    tenant_id = Column(String(36), nullable=False, index=True, comment="租户ID")

    @classmethod
    def tenant_filter(cls):
        """租户过滤条件"""
        return cls.tenant_id == TenantManager.get_current_tenant_id()

    @declared_attr
    def tenant_id(cls):
        return Column(String(36), nullable=False, index=True, comment="租户ID")

@event.listens_for(Base, "init", propagate=True)
def _tenant_init(target, args, kwargs):
    """插入时自动设置租户ID"""
    if hasattr(target, "tenant_id") and not kwargs.get("tenant_id"):
        target.tenant_id = TenantManager.get_current_tenant_id()

2. 修改其他model继承base model

class OtherTable(TenantAwareMixin):
    """其他表
    """

    __tablename__ = 'other_table'

    id = Column(Integer, nullable=False, primary_key=True)  # 主键

3. 使用contextvars记录请求上下文中租户ID

可以增加拦截器在请求头中获取租户ID

current_tenant = contextvars.ContextVar("current_tenant", default=None)
current_tenant.set(获取的租户ID)

4. 自定义Session Class 实现SQL全局拦截

import asyncio

from sqlalchemy import and_, text, Table
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql import Select, Update, Delete

from .engine import engine
from .tenant import TenantManager
from ..models.base import TenantAwareMixin

current_tenant = contextvars.ContextVar("current_tenant", default=None)


class TenantAwareAsyncSession(AsyncSession):
    """租户感知会话"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.skip_tenant_filter = False

    def skip_tenant_filter_for_next_operation(self):
        self.skip_tenant_filter = True

    async def flush(self, *args, **kwargs):
        # 自动校验租户ID一致性
        if not self.skip_tenant_filter:
            for obj in self.new:
                if hasattr(obj, "tenant_id"):
                    current_tenant_id = current_tenant.get()
                    if obj.tenant_id != current_tenant_id:
                        raise ValueError("Tenant ID mismatch during flush")
        self.skip_tenant_filter = False
        await super().flush(*args, **kwargs)

    async def execute(self, statement, params=None, **kwargs):
        # 自动注入租户过滤
        if not self.skip_tenant_filter:
            statement = self._apply_tenant_filter(statement)
        self.skip_tenant_filter = False
        return await super().execute(statement, params, **kwargs)

    async def add(self, instance, _warn=False):
        # 自动注入租户过滤
        if not self.skip_tenant_filter:
            instance.tenant_id = current_tenant.get()
        else:
            self.skip_tenant_filter = False
        return super().add(instance, _warn)

    async def add_all(self, instances):
        # 自动注入租户过滤
        if not self.skip_tenant_filter:
            for instance in instances:
                instance.tenant_id = current_tenant.get()
        else:
            self.skip_tenant_filter = False
        return super().add_all(instances)

    def _add_tenant_condition(self, tables, tenant_id, conditions):
        for table in tables:
            if isinstance(table, Table) and hasattr(table.c, "tenant_id"):
                conditions.append(table.c.tenant_id == tenant_id)
            elif hasattr(table, 'get_children'):
                # 递归处理连接查询中的子表
                self._add_tenant_condition(table.get_children(), tenant_id, conditions)

    def _apply_tenant_filter(self, statement):
        """为所有查询添加租户过滤条件"""
        tenant_id = current_tenant.get()

        if isinstance(statement, Select):
            conditions = []
            # 处理基础表
            if hasattr(statement, 'table'):
                table = statement.table
                if hasattr(table.c, "tenant_id"):
                    conditions.append(table.c.tenant_id == tenant_id)
            # 处理连接查询
            elif hasattr(statement, 'froms'):
                self._add_tenant_condition(statement.froms, tenant_id, conditions)

            if conditions:
                statement = statement.where(and_(*conditions))

        # 处理UPDATE/DELETE
        elif isinstance(statement, (Update, Delete)):
            table = statement.table
            if hasattr(table.c, "tenant_id"):
                statement = statement.where(table.c.tenant_id == tenant_id)

        return statement


async_session_factory = sessionmaker(
    engine,
    class_=TenantAwareAsyncSession,
    autocommit=False,
    autoflush=True,
    expire_on_commit=False
)

# 定义获取当前任务的兼容性函数
def get_current_task():
    try:
        # Python 3.7及以上使用asyncio.current_task()
        return asyncio.current_task()
    except AttributeError:
        # Python 3.6及以下使用asyncio.Task.current_task()
        return asyncio.Task.current_task()

AsyncScopedSession = async_scoped_session(
    async_session_factory,
    scopefunc=lambda: get_current_task()
)

0

评论区