在昨天的文章中,我们讨论了乐观锁(Optimistic Locking)作为高并发场景下保证数据一致性的重要手段。但乐观锁的实现,尤其是基于版本号(Version)或时间戳(Updated At)的 CAS (Compare-and-Swap) 模式,往往需要在应用的每个 Repository 中重复编写大量的样板代码。
今天的核心主题是:如何利用 Rust 过程宏的强大能力,将这些繁琐的持久化逻辑自动化,让开发者只需声明字段,即可获得健壮的乐观锁支持。
宏架构:分治与协作
实现一个完整的、自动化的乐观锁流程,需要宏在两个不同的代码层面进行注入和协作:
- 数据变更层 (
ActiveModelBehavior):负责在数据写入数据库前,自动管理版本号 (version) 和时间戳 (updated_at) 的递增/更新。 - 持久化操作层 (
Repository::save):负责实现核心的原子更新逻辑,即 CAS 检查。
Part 1: ActiveModel 的预处理钩子 (before_save)
这是我们实现乐观锁的第一步:确保在更新操作中,版本号能够正确地 自增。
我们通过宏注入或修改 sea-orm::ActiveModelBehavior Trait 的 before_save 钩子。
宏注入逻辑概览:
// 宏片段:insert_active_model_behavior_impl 的核心逻辑
if need_version {
let version_stmt = quote! {
if insert {
// 插入 (insert=true) 时,版本号初始化为 1
self.version = Set(1);
} else if self.is_changed() {
// 更新 (insert=false) 且模型有业务字段变化时,版本号自增
let current_version = match self.version {
Set(v) => *v,
_ => 0,
};
self.version = Set(current_version + 1);
}
};
}
// updated_at 逻辑类似:非插入且 is_changed 时设置为当前时间
关键成果:
当我们在 Repository 中执行更新操作时,ActiveModel 已经通过 before_save 确保了两个重要事实:
- 它携带着我们从数据库中读出的 旧版本号。
- 它将尝试写入的
version值,是 旧版本号 + 1。
Part 2: Repository 的原子 CAS 更新 (save 方法)
这是乐观锁实现的核心战场,由 fn create_tenant_save_impl 宏片段生成。其逻辑必须严格遵循 三步走 策略,以处理成功、冲突和首次插入三种情况。
Step 1: 原子 UPDATE (Compare-and-Swap)
我们使用 sea-orm 的 update_many 配合 filter 条件,来实现原子性检查。
我们从聚合根 (entity) 中取出 旧版本(即 current_version),并将其作为 WHERE 子句的一部分。
// 宏片段:create_tenant_save_impl 的核心 CAS 逻辑
// 从聚合获取当前版本(即期望的旧版本)
let current_version = entity_model.#optimistic_lock_field_ident();
// 1) 原子 UPDATE(带 version CAS)
let res = models::Entity::update_many()
#id_filters // 主键和 TenantId 过滤
// ⬇️ 核心:只有当数据库中的版本号等于旧版本号时,才允许更新 ⬇️
.filter(models::Column::#optimistic_lock_col_ident.eq(current_version))
.set(update_model.clone())
.exec(&conn)
.await?;
if res.rows_affected > 0 {
// 成功!说明版本匹配,且更新成功写入
// ... 事件处理并返回 Ok(())
return Ok(());
}
如果 rows_affected > 0,任务圆满完成。如果 rows_affected == 0,则进入下一步判断。
Step 2 & 3: 冲突检测与首次插入
如果 CAS 更新失败(rows_affected == 0),我们需要区分是 版本冲突(记录存在但版本号不匹配)还是 首次插入(记录根本不存在)。
// 2) UPDATE 未命中,检查记录是否存在
if models::Entity::find()
#id_filters // 仅按主键和 TenantId 查找
.one(&conn)
.await?
.is_some()
{
// 记录存在,但 Step 1 未命中 -> 乐观锁冲突!
return Err(#crate_root::domain::RepositoryError::optimistic_lock_error(
"Optimistic lock conflict: Version mismatch".to_string(),
));
}
// 3) 记录不存在,执行首次插入
let insert_model: models::Model = entity_model.clone().try_into()?;
let mut active_model = insert_model.into_active_model();
active_model.insert(&conn).await?;
// ... 事件处理并返回 Ok(())
Talk is cheap, show me the code
before_save
fn insert_active_model_behavior_impl(input: &mut ItemMod, model_config: &ModelConfig) {
let Some((_, items)) = &mut input.content else {
return;
};
let mut has_active_model_behavior = false;
for item in items.iter_mut() {
if let syn::Item::Impl(item_impl) = item
&& let Some((_, path, _)) = &item_impl.trait_
&& path.segments.last().unwrap().ident == "ActiveModelBehavior"
{
has_active_model_behavior = true;
break;
}
}
if !has_active_model_behavior {
let active_model_behavior_impl = quote! {
#[async_trait]
impl ActiveModelBehavior for ActiveModel {
async fn before_save<C>(mut self, db: &C, insert: bool) -> Result<Self, DbErr>
where
C: ConnectionTrait,
{
Ok(self)
}
}
};
items.push(parse_quote!(#active_model_behavior_impl));
}
for item in items.iter_mut() {
if let syn::Item::Impl(item_impl) = item
&& let Some((_, path, _)) = &item_impl.trait_
&& path.segments.last().unwrap().ident == "ActiveModelBehavior"
{
let mut has_before_save = false;
for item in item_impl.items.iter_mut() {
if let syn::ImplItem::Fn(method) = item
&& method.sig.ident == "before_save"
{
has_before_save = true;
break;
}
}
if !has_before_save {
let before_save_method = quote! {
async fn before_save<C>(mut self, db: &C, insert: bool) -> Result<Self, DbErr>
where
C: ConnectionTrait,
{
Ok(self)
}
};
item_impl.items.push(parse_quote!(#before_save_method));
}
let need_created_at = model_config
.fields
.iter()
.any(|f| f.ident.as_ref().unwrap() == "created_at");
let need_updated_at = model_config
.fields
.iter()
.any(|f| f.ident.as_ref().unwrap() == "updated_at");
let need_version = model_config
.fields
.iter()
.any(|f| f.ident.as_ref().unwrap() == "version");
if !(need_created_at || need_updated_at || need_version) {
return;
}
for item in item_impl.items.iter_mut() {
if let syn::ImplItem::Fn(method) = item
&& method.sig.ident == "before_save"
{
let mut stmts = Vec::new();
stmts.push(quote! {
let now = chrono::Utc::now();
});
if need_created_at {
let created_at_stmt = quote! {
if insert {
self.created_at = Set(now);
}
};
stmts.push(created_at_stmt);
}
if need_updated_at {
let updated_at_stmt = quote! {
if insert {
self.updated_at = Set(now);
} else if self.is_changed() {
self.updated_at = Set(now);
}
};
stmts.push(updated_at_stmt);
}
if need_version {
let version_stmt = quote! {
if insert {
self.version = Set(1);
} else if self.is_changed() {
let current_version = match self.version {
Set(v) => *v,
_ => 0,
};
self.version = Set(current_version + 1);
}
};
stmts.push(version_stmt);
}
let stmts = parse_quote!({#(#stmts)*});
// 插入到方法体的开头
method.block.stmts.insert(0, stmts);
}
}
}
}
}
宏生成的代码示例
impl ActiveModelBehavior for ActiveModel {
#[allow(
elided_named_lifetimes,
clippy::async_yields_async,
clippy::diverging_sub_expression,
clippy::let_unit_value,
clippy::needless_arbitrary_self_type,
clippy::no_effect_underscore_binding,
clippy::shadow_same,
clippy::type_complexity,
clippy::type_repetition_in_bounds,
clippy::used_underscore_binding
)]
fn before_save<'life0, 'async_trait, C>(
self,
db: &'life0 C,
insert: bool,
) -> ::core::pin::Pin<
Box<
dyn ::core::future::Future<Output = Result<Self, DbErr>>
+ ::core::marker::Send
+ 'async_trait,
>,
>
where
C: ConnectionTrait,
C: 'async_trait,
'life0: 'async_trait,
Self: 'async_trait,
{
Box::pin(async move {
if let ::core::option::Option::Some(__ret) =
::core::option::Option::None::<Result<Self, DbErr>>
{
#[allow(unreachable_code)]
return __ret;
}
let mut __self = self;
let insert = insert;
let __ret: Result<Self, DbErr> = {
{
let now = chrono::Utc::now();
if insert {
__self.created_at = Set(now);
}
if insert {
__self.updated_at = Set(now);
} else if __self.is_changed() {
__self.updated_at = Set(now);
}
}
Ok(__self)
};
#[allow(unreachable_code)]
__ret
})
}
}
Repository::save
fn create_tenant_save_impl(
crate_root: &Path,
aggregate: &Path,
args: &RepositoryStructArgs,
id_filters: &TokenStream,
) -> TokenStream {
// 若指定了乐观锁字段,准备字段名/Column ident
let optimistic_lock_field = args.optimistic_lock_field.as_ref().map(|lit| {
let optimistic_lock_field_name = lit.value();
let optimstic_lock_field_ident = new_id(&optimistic_lock_field_name); // 用于 ActiveModel/Model 字段访问
let optimistic_lock_col_ident = new_id(&to_pascal_case(&optimistic_lock_field_name)); // 用于 models::Column::Xxx
(
optimistic_lock_field_name,
optimstic_lock_field_ident,
optimistic_lock_col_ident,
)
});
// 根据是否指定乐观锁字段,生成 save 的实现
if let Some((
optimistic_lock_field_name,
optimistic_lock_field_ident,
optimistic_lock_col_ident,
)) = optimistic_lock_field
{
if optimistic_lock_field_name == "version" {
quote! {
async fn save(
&self,
txn: &mut TC,
entity: &mut #crate_root::domain::EventSourcedEntity<#aggregate>,
) -> Result<(), #crate_root::domain::RepositoryError> {
use #crate_root::domain::SeaOrmModelUpdater;
use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter};
use sea_orm::ActiveValue::Set;
let conn = txn.get_connection();
let entity_model: &#aggregate = entity;
let id = entity_model.id();
let tenant_id = entity_model.tenant_id();
// 从聚合获取当前版本与期望旧版本
let current_version = entity_model.#optimistic_lock_field_ident();
// 构造用于原子更新的 ActiveModel(只写回必要列)
let mut update_model = models::Model::from(entity_model.clone()).into_active_model();
// 1) 原子 UPDATE(带 version CAS)
let res = models::Entity::update_many()
#id_filters
.filter(models::Column::TenantId.eq(*tenant_id))
.filter(models::Column::#optimistic_lock_col_ident.eq(current_version))
.set(update_model.clone())
.exec(&conn)
.await?;
if res.rows_affected > 0 {
entity.move_event_to_context(txn);
return Ok(());
}
// 2) UPDATE 未命中,检查记录是否存在(按主键 + tenant)
if models::Entity::find()
#id_filters
.filter(models::Column::TenantId.eq(*tenant_id))
.one(&conn)
.await?
.is_some()
{
return Err(#crate_root::domain::RepositoryError::optimistic_lock_error(
"Optimistic lock conflict: Version mismatch".to_string(),
));
}
// 3) 记录不存在,插入数据
let insert_model: models::Model = entity_model.clone().try_into()?;
let mut active_model = insert_model.into_active_model();
active_model.insert(&conn).await?;
entity.move_event_to_context(txn);
Ok(())
}
}
} else {
// treat as timestamp update_at
quote! {
async fn save(
&self,
txn: &mut TC,
entity: &mut #crate_root::domain::EventSourcedEntity<#aggregate>,
) -> Result<(), #crate_root::domain::RepositoryError> {
use #crate_root::domain::SeaOrmModelUpdater;
use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter};
use sea_orm::ActiveValue::Set;
use chrono::Utc;
let conn = txn.get_connection();
let entity_model: &#aggregate = entity;
let id = entity_model.id();
let tenant_id = entity_model.tenant_id();
// 读取实体携带的旧时间戳与准备新的时间戳
let current_ts = entity_model.#optimistic_lock_field_ident();
// 构造用于原子更新的 ActiveModel
let mut update_model = models::Model::from(entity_model.clone()).into_active_model();
// 1) 原子 UPDATE(带 updated_at CAS)
let res = models::Entity::update_many()
#id_filters
.filter(models::Column::TenantId.eq(*tenant_id))
.filter(models::Column::#optimistic_lock_col_ident.eq(current_ts))
.set(update_model.clone())
.exec(&conn)
.await?;
if res.rows_affected > 0 {
entity.move_event_to_context(txn);
return Ok(());
}
// 2) UPDATE 未命中,检查记录是否存在
if models::Entity::find()
#id_filters
.filter(models::Column::TenantId.eq(*tenant_id))
.one(&conn)
.await?
.is_some()
{
return Err(#crate_root::domain::RepositoryError::optimistic_lock_error(
"Optimistic lock conflict".to_string(),
));
}
// 3) 记录不存在,直接插入数据
let insert_model: models::Model = entity_model.clone().try_into()?;
let mut active_model = insert_model.into_active_model();
active_model.insert(&conn).await?;
entity.move_event_to_context(txn);
Ok(())
}
}
}
} else {
// no optimistic lock field -> simple update/insert behavior (原始实现)
quote! {
async fn save(
&self,
txn: &mut TC,
entity: &mut #crate_root::domain::EventSourcedEntity<#aggregate>,
) -> Result<(), #crate_root::domain::RepositoryError> {
use #crate_root::domain::SeaOrmModelUpdater;
use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter};
let conn = txn.get_connection();
let entity_model: &#aggregate = entity;
let id = entity_model.id();
let tenant_id = entity_model.tenant_id();
if let Some(mut model) = models::Entity::find()
#id_filters
.filter(models::Column::TenantId.eq(*tenant_id))
.one(&conn)
.await?
{
if &model.tenant_id != tenant_id {
return Err(#crate_root::domain::RepositoryError::mapping_error(
format!(
"Tenant ID mismatch: expected {}, found {}, id: {}",
tenant_id, model.tenant_id, id
),
));
}
// 更新逻辑
model.update_from_aggregate_root(entity_model).await?;
let active_model = model.into_active_model();
active_model.update(&conn).await?;
} else {
// 创建新记录
let model: models::Model = entity_model.clone().try_into()?;
let active_model = model.into_active_model();
active_model.insert(&conn).await?;
}
entity.move_event_to_context(txn);
Ok(())
}
}
}
}
宏生成的代码示例
async fn save(
&self,
txn: &mut TC,
entity: &mut core_common::domain::EventSourcedEntity<TenantUser>,
) -> Result<(), core_common::domain::RepositoryError> {
use core_common::domain::SeaOrmModelUpdater;
use sea_orm::{
ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter,
};
use sea_orm::ActiveValue::Set;
use chrono::Utc;
let conn = txn.get_connection();
let entity_model: &TenantUser = entity;
let id = entity_model.id();
let current_ts = entity_model.update_at();
let mut update_model = models::Model::from(entity_model.clone())
.into_active_model();
let res = models::Entity::update_many()
.filter(models::Column::Id.eq((id.tenant_id(), id.user_id())))
.filter(models::Column::UpdateAt.eq(current_ts))
.set(update_model.clone())
.exec(&conn)
.await?;
if res.rows_affected > 0 {
entity.move_event_to_context(txn);
return Ok(());
}
if models::Entity::find()
.filter(models::Column::Id.eq((id.tenant_id(), id.user_id())))
.one(&conn)
.await?
.is_some()
{
return Err(
core_common::domain::RepositoryError::optimistic_lock_error(
"Optimistic lock conflict".to_string(),
),
);
}
let insert_model: models::Model = entity_model.clone().try_into()?;
let mut active_model = insert_model.into_active_model();
active_model.insert(&conn).await?;
entity.move_event_to_context(txn);
Ok(())
}
兼容性处理
宏的另一个优势是其灵活性。它能根据字段名称自动适配不同的乐观锁策略:
- 如果检测到字段为
"version",则执行版本号的 CAS 逻辑。 - 如果检测到其他时间戳字段如
"updated_at",则执行基于时间戳的 CAS 逻辑。
结论
通过将 before_save 中的版本递增逻辑,与 Repository::save 中的原子 CAS 检查完美结合,我们使用 Rust 过程宏实现了一个 高内聚、低耦合 的乐观锁基础设施。
开发者现在可以专注于业务逻辑,而将并发控制的复杂性和样板代码完全交给宏来处理。这不仅极大地提高了开发效率,同时也确保了底层持久化操作的健壮性和一致性。
评论区
写评论还没有评论