以SaaS系统为例,租户之间的数据应该是完全隔离的,当然隔离方式有两种:物理隔离和字段隔离
本文以字段隔离为例,通过租户字段完成租户数据的完全隔离,通过自定义async_session_factory的class来实现全局数据库操作的租户隔离
1. 定义base model
创建一个基本模型,其中包含所有表共有的字段,比如租户ID,创建时间,修改时间,修改人,备用字段等
from sqlalchemy import Column, String, event
from sqlalchemy.orm import declarative_base, Mapper, declared_attr
from app.database.tenant import TenantManager
# Declarative base class from which the ORM models below derive.
Base = declarative_base()
class TenantAwareMixin(Base):
    """Abstract base model that adds a ``tenant_id`` column to every subclass.

    The class is marked ``__abstract__`` so it is only a shared parent for
    concrete models; subclasses declare their own ``__tablename__``.
    """

    __abstract__ = True

    # The column is declared once, through ``declared_attr``. The original
    # class also had a plain ``tenant_id = Column(...)`` assignment, which the
    # later ``declared_attr`` of the same name silently replaced.
    @declared_attr
    def tenant_id(cls):
        """Build the tenant ID column for the class being declared."""
        return Column(String(36), nullable=False, index=True, comment="租户ID")

    @classmethod
    def tenant_filter(cls):
        """Return the expression ``tenant_id == <current tenant ID>``.

        The current tenant ID comes from
        ``TenantManager.get_current_tenant_id()``.
        """
        return cls.tenant_id == TenantManager.get_current_tenant_id()
@event.listens_for(Base, "init", propagate=True)
def _tenant_init(target, args, kwargs):
    """Default ``tenant_id`` to the current tenant when a model is constructed.

    Registered for the "init" event on ``Base`` with ``propagate=True``.
    Nothing is assigned when the instance has no ``tenant_id`` attribute or
    when the constructor keyword arguments carry a truthy ``tenant_id``.
    """
    if not hasattr(target, "tenant_id"):
        return
    if kwargs.get("tenant_id"):
        # The caller supplied a tenant explicitly; keep it.
        return
    target.tenant_id = TenantManager.get_current_tenant_id()
2. 修改其他model继承base model
class OtherTable(TenantAwareMixin):
    """Example tenant-scoped model stored in the ``other_table`` table."""

    __tablename__ = "other_table"

    # Primary key.
    # NOTE(review): ``Integer`` is not in the import list shown with the base
    # model; confirm it is imported from sqlalchemy in this module.
    id = Column(Integer, primary_key=True, nullable=False)
3. 使用contextvars记录请求上下文中租户ID
可以增加拦截器在请求头中获取租户ID
# Context-local holder for the tenant ID of the request being handled;
# reads return None until a value has been set in the current context.
current_tenant = contextvars.ContextVar("current_tenant", default=None)
# Placeholder from the article: replace the argument with the tenant ID
# obtained from the request (e.g. read from a request header by an interceptor).
current_tenant.set(获取的租户ID)
4. 自定义Session Class 实现SQL全局拦截
import asyncio
import contextvars

from sqlalchemy import and_, text, Table
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql import Select, Update, Delete

from .engine import engine
from .tenant import TenantManager
from ..models.base import TenantAwareMixin
# Tenant ID for the current execution context; None when no tenant has been set.
# NOTE(review): two ContextVar objects created with the same name are still
# independent variables, so this must be the very object the request
# interceptor calls .set() on (import it instead of redefining it) - confirm.
current_tenant = contextvars.ContextVar("current_tenant", default=None)
class TenantAwareAsyncSession(AsyncSession):
    """AsyncSession that scopes statements and new objects to the current tenant.

    The tenant ID is read from the module-level ``current_tenant`` context
    variable. ``execute`` appends a ``tenant_id`` filter to SELECT, UPDATE and
    DELETE statements, ``add`` / ``add_all`` stamp the tenant ID on instances
    that have a ``tenant_id`` attribute, and ``flush`` rejects pending
    instances whose ``tenant_id`` differs from the current tenant.

    ``skip_tenant_filter_for_next_operation`` disables this handling for the
    next ``execute``, ``flush``, ``add`` or ``add_all`` call only.

    NOTE(review): ``add`` and ``add_all`` are coroutines in this class, so
    callers must ``await`` them; a call without ``await`` creates a coroutine
    that never runs and nothing is added. The base-class methods they wrap
    are invoked here without ``await``.
    NOTE(review): only statements passed through ``execute`` are filtered and
    only explicit ``flush`` calls are validated; verify whether ``scalar``,
    ``get``, ``stream``, autoflush and ``commit`` route through these
    overrides in the installed SQLAlchemy version. INSERT and textual
    statements are passed through unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One-shot flag: when True, the next intercepted operation bypasses
        # tenant handling and then resets the flag.
        self.skip_tenant_filter = False

    def skip_tenant_filter_for_next_operation(self):
        """Bypass tenant handling for the next intercepted operation only."""
        self.skip_tenant_filter = True

    async def flush(self, *args, **kwargs):
        """Validate pending objects against the current tenant, then flush.

        Raises:
            ValueError: a pending object has a ``tenant_id`` that differs
                from the current tenant ID (checked unless the skip flag
                is set).
        """
        if not self.skip_tenant_filter:
            # The context value cannot change inside this synchronous loop,
            # so read it once.
            current_tenant_id = current_tenant.get()
            for obj in self.new:
                if hasattr(obj, "tenant_id") and obj.tenant_id != current_tenant_id:
                    raise ValueError("Tenant ID mismatch during flush")
        self.skip_tenant_filter = False
        await super().flush(*args, **kwargs)

    async def execute(self, statement, params=None, **kwargs):
        """Execute ``statement`` after injecting the tenant filter.

        The filter is not injected when the skip flag is set; the flag is
        reset either way.
        """
        if not self.skip_tenant_filter:
            statement = self._apply_tenant_filter(statement)
        self.skip_tenant_filter = False
        return await super().execute(statement, params, **kwargs)

    async def add(self, instance, _warn=False):
        """Stamp the current tenant ID on ``instance`` and add it to the session.

        Instances without a ``tenant_id`` attribute are added unchanged,
        matching the ``hasattr`` guard used by ``flush``.
        """
        if self.skip_tenant_filter:
            self.skip_tenant_filter = False
        elif hasattr(instance, "tenant_id"):
            instance.tenant_id = current_tenant.get()
        return super().add(instance, _warn)

    async def add_all(self, instances):
        """Stamp the current tenant ID on each instance and add them all."""
        # Materialize first: the collection is walked here and again by the
        # base class, so a generator argument would otherwise be exhausted
        # before anything is added.
        instances = list(instances)
        if self.skip_tenant_filter:
            self.skip_tenant_filter = False
        else:
            tenant_id = current_tenant.get()
            for instance in instances:
                if hasattr(instance, "tenant_id"):
                    instance.tenant_id = tenant_id
        return super().add_all(instances)

    def _add_tenant_condition(self, tables, tenant_id, conditions):
        """Append ``<from>.c.tenant_id == tenant_id`` for each FROM element.

        Any element exposing a ``tenant_id`` column through ``.c`` is filtered
        on that column. Checking ``.c`` instead of requiring a plain ``Table``
        lets aliased tables and subqueries be filtered on their own column
        rather than on the underlying table. Elements without such a column
        (joins, for example) are searched through ``get_children()``.
        """
        for selectable in tables:
            columns = getattr(selectable, "c", None)
            if columns is not None and hasattr(columns, "tenant_id"):
                conditions.append(columns.tenant_id == tenant_id)
            elif hasattr(selectable, "get_children"):
                # Recurse into the child elements of composite FROM objects.
                self._add_tenant_condition(
                    selectable.get_children(), tenant_id, conditions
                )

    def _apply_tenant_filter(self, statement):
        """Return ``statement`` with tenant conditions added where applicable.

        SELECT statements get one condition per FROM element that exposes a
        ``tenant_id`` column; UPDATE and DELETE statements get a condition on
        their target table. Any other statement is returned unchanged.
        """
        tenant_id = current_tenant.get()
        if isinstance(statement, Select):
            conditions = []
            # Prefer get_final_froms() where it exists; the ``froms``
            # attribute it replaces is deprecated in newer SQLAlchemy.
            get_final_froms = getattr(statement, "get_final_froms", None)
            froms = get_final_froms() if callable(get_final_froms) else statement.froms
            self._add_tenant_condition(froms, tenant_id, conditions)
            if conditions:
                statement = statement.where(and_(*conditions))
        elif isinstance(statement, (Update, Delete)):
            table = statement.table
            if hasattr(table.c, "tenant_id"):
                statement = statement.where(table.c.tenant_id == tenant_id)
        return statement
# Session factory bound to ``engine``; every session it creates is a
# TenantAwareAsyncSession, so the tenant handling applies to all of them.
# NOTE(review): SQLAlchemy 2.0 provides ``async_sessionmaker`` for this
# purpose - confirm which version the project targets.
async_session_factory = sessionmaker(
    engine,
    class_=TenantAwareAsyncSession,  # use the tenant-aware session subclass
    autocommit=False,
    autoflush=True,
    expire_on_commit=False
)
# Compatibility helper for looking up the currently running asyncio task.
def get_current_task():
    """Return the asyncio task that is currently running.

    Uses ``asyncio.current_task()`` when the module provides it (Python 3.7
    and later) and falls back to ``asyncio.Task.current_task()`` otherwise.
    """
    module_level_lookup = getattr(asyncio, "current_task", None)
    if module_level_lookup is None:
        # Older interpreters only expose the lookup on the Task class.
        return asyncio.Task.current_task()
    return module_level_lookup()
# Scoped-session registry whose scope key is the value returned by
# get_current_task(), i.e. the asyncio task that is running at lookup time.
AsyncScopedSession = async_scoped_session(
    async_session_factory,
    scopefunc=get_current_task,
)
评论区