以SaaS系统为例,租户之间的数据应当完全隔离。隔离方式通常有两种:物理隔离和字段隔离。
本文以字段隔离为例:在每张业务表中增加租户ID字段来区分不同租户的数据,并通过自定义 async_session_factory 所使用的 Session 类(class_)来实现全局数据库操作的租户隔离。
1. 定义base model
创建一个基础模型,用于声明所有表都包含的公共字段,比如租户ID、创建时间、修改时间、修改人、备用字段等(下面的示例只包含租户ID字段)。
from sqlalchemy import Column, String, event
from sqlalchemy.orm import declarative_base, Mapper, declared_attr
from app.database.tenant import TenantManager
# Shared declarative base: TenantAwareMixin derives from it, and the "init"
# listener below is registered on it with propagate=True.
Base = declarative_base()
class TenantAwareMixin(Base):
    """Abstract tenant-aware base: gives every subclass table a ``tenant_id`` column."""

    __abstract__ = True

    @declared_attr
    def tenant_id(cls):
        # declared_attr builds a fresh Column for each concrete subclass, so
        # tables never share a single Column object.
        return Column(String(36), nullable=False, index=True, comment="租户ID")

    @classmethod
    def tenant_filter(cls):
        """Return a SQL criterion restricting ``cls`` rows to the current tenant.

        The tenant comes from ``TenantManager.get_current_tenant_id()``.
        """
        return cls.tenant_id == TenantManager.get_current_tenant_id()
@event.listens_for(Base, "init", propagate=True)
def _tenant_init(target, args, kwargs):
    """插入时自动设置租户ID"""
    if hasattr(target, "tenant_id") and not kwargs.get("tenant_id"):
        target.tenant_id = TenantManager.get_current_tenant_id()2. 修改其他model继承base model
class OtherTable(TenantAwareMixin):
    """其他表
    """
    __tablename__ = 'other_table'
    id = Column(Integer, nullable=False, primary_key=True)  # 主键3. 使用contextvars记录请求上下文中租户ID
可以增加一个拦截器(例如 Web 框架的中间件),从请求头中获取租户ID,并写入上下文变量:
current_tenant = contextvars.ContextVar("current_tenant", default=None)
current_tenant.set(获取的租户ID)4. 自定义Session Class 实现SQL全局拦截
import asyncio
import contextvars

from sqlalchemy import and_, text, Table
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql import Select, Update, Delete
from sqlalchemy.sql.expression import Join

from .engine import engine
from .tenant import TenantManager
from ..models.base import TenantAwareMixin
# NOTE(review): every ContextVar(...) call creates a new, independent variable.
# This object is distinct from any ContextVar("current_tenant") defined in
# another module (e.g. the request middleware), so a value set there is NOT
# visible through this one. Import the shared instance instead of redefining it.
current_tenant = contextvars.ContextVar("current_tenant", default=None)
class TenantAwareAsyncSession(AsyncSession):
    """Async session that scopes data operations to the current tenant.

    - ``add`` / ``add_all`` stamp the current tenant on instances whose class
      has a ``tenant_id`` attribute.
    - ``flush`` rejects pending tenant-aware objects whose ``tenant_id``
      differs from the current tenant.
    - ``execute`` appends ``tenant_id = <current tenant>`` to SELECT (for each
      tenant-aware entry of the top-level FROM list), UPDATE and DELETE
      statements. Subqueries nested in WHERE/column clauses and other
      statement types (e.g. ``text()``) are passed through unchanged.

    The current tenant is read from ``TenantManager.get_current_tenant_id()``,
    the same source the model layer uses when constructing instances.

    ``skip_tenant_filter_for_next_operation()`` disables tenant handling for
    exactly one subsequent add/add_all/flush/execute call.

    NOTE(review): autoflush and the flush done inside ``commit()`` presumably
    run on the underlying sync Session and bypass the ``flush`` override
    below. If so, the consistency check only runs on an explicit
    ``await session.flush()``. Confirm against the SQLAlchemy version in use.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One-shot flag. When True, the next add/add_all/flush/execute skips
        # tenant handling and clears the flag again.
        self.skip_tenant_filter = False

    @staticmethod
    def _current_tenant_id():
        """Return the tenant ID of the current context (via TenantManager)."""
        return TenantManager.get_current_tenant_id()

    @staticmethod
    def _is_tenant_aware(obj):
        """Return True if ``obj``'s class exposes a ``tenant_id`` attribute."""
        # Check the class rather than the instance so that a stray instance
        # attribute never makes a non-tenant model look tenant-aware.
        return hasattr(type(obj), "tenant_id")

    def skip_tenant_filter_for_next_operation(self):
        """Bypass tenant handling for the next add/add_all/flush/execute call."""
        self.skip_tenant_filter = True

    async def flush(self, *args, **kwargs):
        """Verify pending tenant-aware objects belong to the current tenant, then flush.

        Raises:
            ValueError: a pending tenant-aware object carries a different
                ``tenant_id`` than the current tenant.
        """
        skip = self.skip_tenant_filter
        self.skip_tenant_filter = False
        if not skip:
            tenant_id = self._current_tenant_id()
            for obj in self.new:
                if self._is_tenant_aware(obj) and obj.tenant_id != tenant_id:
                    raise ValueError("Tenant ID mismatch during flush")
        await super().flush(*args, **kwargs)

    async def execute(self, statement, params=None, **kwargs):
        """Execute ``statement`` after injecting the tenant criterion (unless skipped)."""
        if self.skip_tenant_filter:
            self.skip_tenant_filter = False
        else:
            statement = self._apply_tenant_filter(statement)
        return await super().execute(statement, params, **kwargs)

    def add(self, instance, _warn=False):
        """Stamp the current tenant on ``instance`` (if tenant-aware) and add it.

        Must stay synchronous: ``AsyncSession.add`` is a plain method, so
        callers invoke it without ``await``. A coroutine here would never run.
        """
        if self.skip_tenant_filter:
            self.skip_tenant_filter = False
        elif self._is_tenant_aware(instance):
            instance.tenant_id = self._current_tenant_id()
        super().add(instance, _warn)

    def add_all(self, instances):
        """Stamp the current tenant on each tenant-aware instance and add them all.

        Synchronous, like ``AsyncSession.add_all``.
        """
        # Materialize first: ``instances`` may be a one-shot iterator, and it
        # is iterated here and again by the parent implementation.
        instances = list(instances)
        if self.skip_tenant_filter:
            self.skip_tenant_filter = False
        else:
            tenant_id = self._current_tenant_id()
            for instance in instances:
                if self._is_tenant_aware(instance):
                    instance.tenant_id = tenant_id
        super().add_all(instances)

    def _add_tenant_condition(self, tables, tenant_id, conditions):
        """Append ``<from>.tenant_id == tenant_id`` to ``conditions`` for each tenant-aware FROM.

        Joins are walked into their left/right sides. Any other FROM object
        (table, alias, subquery) with a ``tenant_id`` column is filtered on
        its own column. Aliased entities are therefore constrained through
        the alias, not the underlying table (which would add a stray FROM).
        """
        for table in tables:
            if isinstance(table, Join):
                self._add_tenant_condition(
                    (table.left, table.right), tenant_id, conditions
                )
                continue
            columns = getattr(table, "c", None)
            if columns is not None and hasattr(columns, "tenant_id"):
                conditions.append(columns.tenant_id == tenant_id)

    def _apply_tenant_filter(self, statement):
        """Return ``statement`` with the current-tenant criterion added where applicable."""
        tenant_id = self._current_tenant_id()
        if isinstance(statement, Select):
            # get_final_froms() replaces the deprecated ``froms`` accessor
            # (SQLAlchemy 1.4.23+); fall back to ``froms`` on older releases.
            get_froms = getattr(statement, "get_final_froms", None)
            froms = get_froms() if get_froms is not None else statement.froms
            conditions = []
            self._add_tenant_condition(froms, tenant_id, conditions)
            if conditions:
                statement = statement.where(and_(*conditions))
        elif isinstance(statement, (Update, Delete)):
            table = statement.table
            if hasattr(table.c, "tenant_id"):
                statement = statement.where(table.c.tenant_id == tenant_id)
        return statement
# Factory producing TenantAwareAsyncSession instances bound to ``engine``.
async_session_factory = sessionmaker(
    bind=engine,
    class_=TenantAwareAsyncSession,
    expire_on_commit=False,  # instances are not expired after commit
    autoflush=True,
    autocommit=False,
)
# Compatibility helper for looking up the running asyncio task.
def get_current_task():
    """Return the asyncio task that is currently running.

    Uses the module-level ``asyncio.current_task`` (Python 3.7+) when it
    exists, otherwise the legacy ``asyncio.Task.current_task``.
    """
    lookup = getattr(asyncio, "current_task", None)
    if lookup is None:
        lookup = asyncio.Task.current_task
    return lookup()
# Session registry keyed by the running asyncio task: each task obtains its
# own TenantAwareAsyncSession from ``async_session_factory``.
# NOTE(review): SQLAlchemy's docs advise awaiting ``AsyncScopedSession.remove()``
# when a task finishes; that cleanup is not visible here.
AsyncScopedSession = async_scoped_session(
    async_session_factory,
    scopefunc=get_current_task,
)
             
           
             
                        
评论区