< 返回我的博客

blossom-001 发表于 2025-11-18 20:33

在昨天的文章中,我们讨论了乐观锁(Optimistic Locking)作为高并发场景下保证数据一致性的重要手段。但乐观锁的实现,尤其是基于版本号(Version)或时间戳(Updated At)的 CAS (Compare-and-Swap) 模式,往往需要在应用的每个 Repository 中重复编写大量的样板代码。

今天的核心主题是:如何利用 Rust 过程宏的强大能力,将这些繁琐的持久化逻辑自动化,让开发者只需声明字段,即可获得健壮的乐观锁支持。


宏架构:分治与协作

实现一个完整的、自动化的乐观锁流程,需要宏在两个不同的代码层面进行注入和协作:

  1. 数据变更层 (ActiveModelBehavior):负责在数据写入数据库前,自动管理版本号 (version) 和时间戳 (updated_at) 的递增/更新。
  2. 持久化操作层 (Repository::save):负责实现核心的原子更新逻辑,即 CAS 检查

Part 1: ActiveModel 的预处理钩子 (before_save)

这是我们实现乐观锁的第一步:确保在更新操作中,版本号能够正确地 自增

我们通过宏注入或修改 sea-orm::ActiveModelBehavior Trait 的 before_save 钩子。

宏注入逻辑概览:

// 宏片段:insert_active_model_behavior_impl 的核心逻辑
if need_version {
    let version_stmt = quote! {
        if insert {
            // 插入 (insert=true) 时,版本号初始化为 1
            self.version = Set(1);
        } else if self.is_changed() {
            // 更新 (insert=false) 且模型有业务字段变化时,版本号自增
            let current_version = match self.version {
                Set(v) => *v,
                _ => 0,
            };
            self.version = Set(current_version + 1);
        }
    };
}
// updated_at 逻辑类似:非插入且 is_changed 时设置为当前时间

关键成果: 当我们在 Repository 中执行更新操作时,ActiveModel 已经通过 before_save 确保了两个重要事实:

  1. 它携带着我们从数据库中读出的 旧版本号
  2. 它将尝试写入的 version 值,是 旧版本号 + 1

Part 2: Repository 的原子 CAS 更新 (save 方法)

这是乐观锁实现的核心战场,由 fn create_tenant_save_impl 宏片段生成。其逻辑必须严格遵循 三步走 策略,以处理成功、冲突和首次插入三种情况。

Step 1: 原子 UPDATE (Compare-and-Swap)

我们使用 sea-ormupdate_many 配合 filter 条件,来实现原子性检查。

我们从聚合根 (entity) 中取出 旧版本(即 current_version),并将其作为 WHERE 子句的一部分。

// 宏片段:create_tenant_save_impl 的核心 CAS 逻辑

// 从聚合获取当前版本(即期望的旧版本)
let current_version = entity_model.#optimistic_lock_field_ident();

// 1) 原子 UPDATE(带 version CAS)
let res = models::Entity::update_many()
    #id_filters // 主键和 TenantId 过滤
    // ⬇️ 核心:只有当数据库中的版本号等于旧版本号时,才允许更新 ⬇️
    .filter(models::Column::#optimistic_lock_col_ident.eq(current_version)) 
    .set(update_model.clone())
    .exec(&conn)
    .await?;

if res.rows_affected > 0 {
    // 成功!说明版本匹配,且更新成功写入
    // ... 事件处理并返回 Ok(())
    return Ok(());
}

如果 rows_affected > 0,任务圆满完成。如果 rows_affected == 0,则进入下一步判断。

Step 2 & 3: 冲突检测与首次插入

如果 CAS 更新失败(rows_affected == 0),我们需要区分是 版本冲突(记录存在但版本号不匹配)还是 首次插入(记录根本不存在)。

// 2) UPDATE 未命中,检查记录是否存在
if models::Entity::find()
    #id_filters // 仅按主键和 TenantId 查找
    .one(&conn)
    .await?
    .is_some()
{
    // 记录存在,但 Step 1 未命中 -> 乐观锁冲突!
    return Err(#crate_root::domain::RepositoryError::optimistic_lock_error(
        "Optimistic lock conflict: Version mismatch".to_string(),
    ));
}

// 3) 记录不存在,执行首次插入
let insert_model: models::Model = entity_model.clone().try_into()?;
let mut active_model = insert_model.into_active_model();
active_model.insert(&conn).await?;
// ... 事件处理并返回 Ok(())

Talk is cheap, show me the code

before_save

fn insert_active_model_behavior_impl(input: &mut ItemMod, model_config: &ModelConfig) {
  let Some((_, items)) = &mut input.content else {
      return;
  };

  let mut has_active_model_behavior = false;
  for item in items.iter_mut() {
      if let syn::Item::Impl(item_impl) = item
          && let Some((_, path, _)) = &item_impl.trait_
          && path.segments.last().unwrap().ident == "ActiveModelBehavior"
      {
          has_active_model_behavior = true;
          break;
      }
  }

  if !has_active_model_behavior {
      let active_model_behavior_impl = quote! {
          #[async_trait]
          impl ActiveModelBehavior for ActiveModel {
              async fn before_save<C>(mut self, db: &C, insert: bool) -> Result<Self, DbErr>
              where
                  C: ConnectionTrait,
              {
                  Ok(self)
              }
          }
      };
      items.push(parse_quote!(#active_model_behavior_impl));
  }

  for item in items.iter_mut() {
      if let syn::Item::Impl(item_impl) = item
          && let Some((_, path, _)) = &item_impl.trait_
          && path.segments.last().unwrap().ident == "ActiveModelBehavior"
      {
          let mut has_before_save = false;
          for item in item_impl.items.iter_mut() {
              if let syn::ImplItem::Fn(method) = item
                  && method.sig.ident == "before_save"
              {
                  has_before_save = true;
                  break;
              }
          }

          if !has_before_save {
              let before_save_method = quote! {
                  async fn before_save<C>(mut self, db: &C, insert: bool) -> Result<Self, DbErr>
                  where
                      C: ConnectionTrait,
                  {
                      Ok(self)
                  }
              };
              item_impl.items.push(parse_quote!(#before_save_method));
          }

          let need_created_at = model_config
              .fields
              .iter()
              .any(|f| f.ident.as_ref().unwrap() == "created_at");
          let need_updated_at = model_config
              .fields
              .iter()
              .any(|f| f.ident.as_ref().unwrap() == "updated_at");

          let need_version = model_config
              .fields
              .iter()
              .any(|f| f.ident.as_ref().unwrap() == "version");

          if !(need_created_at || need_updated_at || need_version) {
              return;
          }

          for item in item_impl.items.iter_mut() {
              if let syn::ImplItem::Fn(method) = item
                  && method.sig.ident == "before_save"
              {
                  let mut stmts = Vec::new();
                  stmts.push(quote! {
                      let now = chrono::Utc::now();
                  });

                  if need_created_at {
                      let created_at_stmt = quote! {
                          if insert {
                              self.created_at = Set(now);
                          }
                      };
                      stmts.push(created_at_stmt);
                  }
                  if need_updated_at {
                      let updated_at_stmt = quote! {
                          if insert {
                              self.updated_at = Set(now);
                          } else if self.is_changed() {
                              self.updated_at = Set(now);
                          }
                      };
                      stmts.push(updated_at_stmt);
                  }

                  if need_version {
                      let version_stmt = quote! {
                          if insert {
                              self.version = Set(1);
                          } else if self.is_changed() {
                              let current_version = match self.version {
                              Set(v) => *v,
                              _ => 0,
                          };
                              self.version = Set(current_version + 1);
                          }
                      };
                      stmts.push(version_stmt);
                  }

                  let stmts = parse_quote!({#(#stmts)*});

                  // 插入到方法体的开头
                  method.block.stmts.insert(0, stmts);
              }
          }
      }
  }
}

宏生成的代码示例

 impl ActiveModelBehavior for ActiveModel {
        #[allow(
            elided_named_lifetimes,
            clippy::async_yields_async,
            clippy::diverging_sub_expression,
            clippy::let_unit_value,
            clippy::needless_arbitrary_self_type,
            clippy::no_effect_underscore_binding,
            clippy::shadow_same,
            clippy::type_complexity,
            clippy::type_repetition_in_bounds,
            clippy::used_underscore_binding
        )]
        fn before_save<'life0, 'async_trait, C>(
            self,
            db: &'life0 C,
            insert: bool,
        ) -> ::core::pin::Pin<
            Box<
                dyn ::core::future::Future<Output = Result<Self, DbErr>>
                    + ::core::marker::Send
                    + 'async_trait,
            >,
        >
        where
            C: ConnectionTrait,
            C: 'async_trait,
            'life0: 'async_trait,
            Self: 'async_trait,
        {
            Box::pin(async move {
                if let ::core::option::Option::Some(__ret) =
                    ::core::option::Option::None::<Result<Self, DbErr>>
                {
                    #[allow(unreachable_code)]
                    return __ret;
                }
                let mut __self = self;
                let insert = insert;
                let __ret: Result<Self, DbErr> = {
                    {
                        let now = chrono::Utc::now();
                        if insert {
                            __self.created_at = Set(now);
                        }
                        if insert {
                            __self.updated_at = Set(now);
                        } else if __self.is_changed() {
                            __self.updated_at = Set(now);
                        }
                    }
                    Ok(__self)
                };
                #[allow(unreachable_code)]
                __ret
            })
        }
    }

Repository::save

fn create_tenant_save_impl(
    crate_root: &Path,
    aggregate: &Path,
    args: &RepositoryStructArgs,
    id_filters: &TokenStream,
) -> TokenStream {
    // 若指定了乐观锁字段,准备字段名/Column ident
    let optimistic_lock_field = args.optimistic_lock_field.as_ref().map(|lit| {
        let optimistic_lock_field_name = lit.value();
        let optimstic_lock_field_ident = new_id(&optimistic_lock_field_name); // 用于 ActiveModel/Model 字段访问
        let optimistic_lock_col_ident = new_id(&to_pascal_case(&optimistic_lock_field_name)); // 用于 models::Column::Xxx
        (
            optimistic_lock_field_name,
            optimstic_lock_field_ident,
            optimistic_lock_col_ident,
        )
    });

    // 根据是否指定乐观锁字段,生成 save 的实现
    if let Some((
        optimistic_lock_field_name,
        optimistic_lock_field_ident,
        optimistic_lock_col_ident,
    )) = optimistic_lock_field
    {
        if optimistic_lock_field_name == "version" {
            quote! {
                async fn save(
                    &self,
                    txn: &mut TC,
                    entity: &mut #crate_root::domain::EventSourcedEntity<#aggregate>,
                ) -> Result<(), #crate_root::domain::RepositoryError> {
                    use #crate_root::domain::SeaOrmModelUpdater;
                    use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter};
                    use sea_orm::ActiveValue::Set;

                    let conn = txn.get_connection();
                    let entity_model: &#aggregate = entity;

                    let id = entity_model.id();
                    let tenant_id = entity_model.tenant_id();

                    // 从聚合获取当前版本与期望旧版本
                    let current_version = entity_model.#optimistic_lock_field_ident();

                    // 构造用于原子更新的 ActiveModel(只写回必要列)
                    let mut update_model = models::Model::from(entity_model.clone()).into_active_model();

                    // 1) 原子 UPDATE(带 version CAS)
                    let res = models::Entity::update_many()
                        #id_filters
                        .filter(models::Column::TenantId.eq(*tenant_id))
                        .filter(models::Column::#optimistic_lock_col_ident.eq(current_version))
                        .set(update_model.clone())
                        .exec(&conn)
                        .await?;

                    if res.rows_affected > 0 {
                        entity.move_event_to_context(txn);
                        return Ok(());
                    }

                    // 2) UPDATE 未命中,检查记录是否存在(按主键 + tenant)
                    if models::Entity::find()
                        #id_filters
                        .filter(models::Column::TenantId.eq(*tenant_id))
                        .one(&conn)
                        .await?
                        .is_some()
                    {
                        return Err(#crate_root::domain::RepositoryError::optimistic_lock_error(
                            "Optimistic lock conflict: Version mismatch".to_string(),
                        ));
                    }

                    // 3) 记录不存在,插入数据
                    let insert_model: models::Model = entity_model.clone().try_into()?;
                    let mut active_model = insert_model.into_active_model();
                    active_model.insert(&conn).await?;
                    entity.move_event_to_context(txn);
                    Ok(())
                }
            }
        } else {
            // treat as timestamp update_at
            quote! {
                async fn save(
                    &self,
                    txn: &mut TC,
                    entity: &mut #crate_root::domain::EventSourcedEntity<#aggregate>,
                ) -> Result<(), #crate_root::domain::RepositoryError> {
                    use #crate_root::domain::SeaOrmModelUpdater;
                    use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter};
                    use sea_orm::ActiveValue::Set;
                    use chrono::Utc;

                    let conn = txn.get_connection();
                    let entity_model: &#aggregate = entity;

                    let id = entity_model.id();
                    let tenant_id = entity_model.tenant_id();

                    // 读取实体携带的旧时间戳与准备新的时间戳
                    let current_ts = entity_model.#optimistic_lock_field_ident();

                    // 构造用于原子更新的 ActiveModel
                    let mut update_model = models::Model::from(entity_model.clone()).into_active_model();

                    // 1) 原子 UPDATE(带 updated_at CAS)
                    let res = models::Entity::update_many()
                        #id_filters
                        .filter(models::Column::TenantId.eq(*tenant_id))
                        .filter(models::Column::#optimistic_lock_col_ident.eq(current_ts))
                        .set(update_model.clone())
                        .exec(&conn)
                        .await?;

                    if res.rows_affected > 0 {
                        entity.move_event_to_context(txn);
                        return Ok(());
                    }

                    // 2) UPDATE 未命中,检查记录是否存在
                    if models::Entity::find()
                        #id_filters
                        .filter(models::Column::TenantId.eq(*tenant_id))
                        .one(&conn)
                        .await?
                        .is_some()
                    {
                        return Err(#crate_root::domain::RepositoryError::optimistic_lock_error(
                            "Optimistic lock conflict".to_string(),
                        ));
                    }

                    // 3) 记录不存在,直接插入数据
                    let insert_model: models::Model = entity_model.clone().try_into()?;
                    let mut active_model = insert_model.into_active_model();

                    active_model.insert(&conn).await?;
                    entity.move_event_to_context(txn);
                    Ok(())
                }
            }
        }
    } else {
        // no optimistic lock field -> simple update/insert behavior (原始实现)
        quote! {
            async fn save(
                &self,
                txn: &mut TC,
                entity: &mut #crate_root::domain::EventSourcedEntity<#aggregate>,
            ) -> Result<(), #crate_root::domain::RepositoryError> {
                use #crate_root::domain::SeaOrmModelUpdater;
                use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter};

                let conn = txn.get_connection();

                let entity_model: &#aggregate = entity;

                let id = entity_model.id();
                let tenant_id = entity_model.tenant_id();

                if let Some(mut model) = models::Entity::find()
                    #id_filters
                    .filter(models::Column::TenantId.eq(*tenant_id))
                    .one(&conn)
                    .await?
                {
                    if &model.tenant_id != tenant_id {
                        return Err(#crate_root::domain::RepositoryError::mapping_error(
                            format!(
                                "Tenant ID mismatch: expected {}, found {}, id: {}",
                                tenant_id, model.tenant_id, id
                            ),
                        ));
                    }

                    // 更新逻辑
                    model.update_from_aggregate_root(entity_model).await?;

                    let active_model = model.into_active_model();
                    active_model.update(&conn).await?;
                } else {
                    // 创建新记录
                    let model: models::Model = entity_model.clone().try_into()?;
                    let active_model = model.into_active_model();
                    active_model.insert(&conn).await?;
                }

                entity.move_event_to_context(txn);
                Ok(())
            }
        }
    }
}

宏生成的代码示例

  async fn save(
        &self,
        txn: &mut TC,
        entity: &mut core_common::domain::EventSourcedEntity<TenantUser>,
    ) -> Result<(), core_common::domain::RepositoryError> {
        use core_common::domain::SeaOrmModelUpdater;
        use sea_orm::{
            ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter,
        };
        use sea_orm::ActiveValue::Set;
        use chrono::Utc;
        let conn = txn.get_connection();
        let entity_model: &TenantUser = entity;
        let id = entity_model.id();
        let current_ts = entity_model.update_at();
        let mut update_model = models::Model::from(entity_model.clone())
            .into_active_model();
        let res = models::Entity::update_many()
            .filter(models::Column::Id.eq((id.tenant_id(), id.user_id())))
            .filter(models::Column::UpdateAt.eq(current_ts))
            .set(update_model.clone())
            .exec(&conn)
            .await?;
        if res.rows_affected > 0 {
            entity.move_event_to_context(txn);
            return Ok(());
        }
        if models::Entity::find()
            .filter(models::Column::Id.eq((id.tenant_id(), id.user_id())))
            .one(&conn)
            .await?
            .is_some()
        {
            return Err(
                core_common::domain::RepositoryError::optimistic_lock_error(
                    "Optimistic lock conflict".to_string(),
                ),
            );
        }
        let insert_model: models::Model = entity_model.clone().try_into()?;
        let mut active_model = insert_model.into_active_model();
        active_model.insert(&conn).await?;
        entity.move_event_to_context(txn);
        Ok(())
    }

兼容性处理

宏的另一个优势是其灵活性。它能根据字段名称自动适配不同的乐观锁策略:

  • 如果检测到字段为 "version",则执行版本号的 CAS 逻辑。
  • 如果检测到其他时间戳字段如 "updated_at",则执行基于时间戳的 CAS 逻辑。

结论

通过将 before_save 中的版本递增逻辑,与 Repository::save 中的原子 CAS 检查完美结合,我们使用 Rust 过程宏实现了一个 高内聚、低耦合 的乐观锁基础设施。

开发者现在可以专注于业务逻辑,而将并发控制的复杂性和样板代码完全交给宏来处理。这不仅极大地提高了开发效率,同时也确保了底层持久化操作的健壮性和一致性。

评论区

写评论

还没有评论

1 共 0 条评论, 1 页