< 返回我的博客

blossom-001 发表于 2025-11-18 20:31

在 Rust 错误处理中,我们经常需要记录错误发生的位置信息以便调试。虽然 snafu 库提供了强大的错误处理能力,但手动为每个错误变体添加位置字段和工厂方法仍然繁琐且容易出错。本文介绍一个自定义的过程宏 #[with_err_location],它可以自动化这些重复工作,让错误处理更加优雅和高效。

问题背景

使用 snafu 进行错误处理时,我们通常需要:

  1. 为每个错误变体手动添加 location 字段
  2. 添加相应的属性(#[snafu(implicit)]#[serde(skip)]
  3. 为复杂的 source 字段添加 #[snafu(source(false))]
  4. 手动实现工厂方法来创建错误实例

这导致了大量的样板代码:

#[derive(Debug, Serialize, Snafu)]
#[serde(tag = "type")]
pub enum ApiError {
    #[serde(rename = "validate_error")]
    ValidateError {
        message: String,
        #[serde(skip)]
        #[snafu(implicit)]
        location: snafu::Location,
    },
    
    #[serde(rename = "internal_error")]
    InternalError {
        message: String,
        #[serde(skip)]
        #[snafu(source(false))]
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
        #[serde(skip)]
        #[snafu(implicit)]
        location: snafu::Location,
    },
}

impl ApiError {
    #[track_caller]
    pub fn validate_error(message: String) -> Self {
        ApiError::ValidateError {
            message,
            location: GenerateImplicitData::generate(),
        }
    }
    
    #[track_caller]
    pub fn internal_error(message: String) -> Self {
        ApiError::InternalError {
            message,
            source: None,
            location: GenerateImplicitData::generate(),
        }
    }
    
    #[track_caller]
    pub fn internal_error_with_source(message: String, source: Option<Box<dyn std::error::Error + Send + Sync>>) -> Self {
        ApiError::InternalError {
            message,
            source,
            location: GenerateImplicitData::generate(),
        }
    }
}

解决方案:#[with_err_location]

#[with_err_location] 宏可以自动化所有这些工作,让您只需要定义核心的错误结构:

#[with_err_location]
#[derive(Debug, Serialize, Snafu)]
#[serde(tag = "type")]
pub enum ApiError {
    #[serde(rename = "validate_error")]
    ValidateError {
        message: String,
    },
    
    #[serde(rename = "internal_error")]
    InternalError {
        message: String,
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
    },
}

核心特性

1. 自动添加 Location 字段

宏会为每个枚举变体自动添加 location: snafu::Location 字段,并配置必要的属性:

  • #[snafu(implicit)]:让 snafu 自动填充位置信息
  • #[serde(skip)]:在序列化时跳过该字段(默认行为)

2. 智能 Source 字段处理

宏能识别复杂的 source 字段类型,并自动添加 #[snafu(source(false))] 属性:

// 自动识别并处理
source: Option<Box<dyn std::error::Error + Send + Sync>>

3. 自动生成工厂方法

宏为每个变体生成相应的工厂方法:

普通变体

// 生成:
pub fn validate_error(message: String) -> Self { ... }

复杂 Source 字段变体

对于包含 Option<Box<dyn Error + Send + Sync>> 类型的 source 字段,宏会生成两个方法:

// 基础方法(source = None)
pub fn internal_error(message: String) -> Self { ... }

// 带 source 的方法
pub fn internal_error_with_source(message: String, source: Option<Box<dyn std::error::Error + Send + Sync>>) -> Self { ... }

4. 灵活的配置选项

全局配置

#[with_err_location(serde = true)]  // 不添加 #[serde(skip)]
#[derive(Debug, Snafu)]
pub enum ApiError { ... }

变体级别配置

#[with_err_location]
#[derive(Debug, Snafu)]
pub enum ApiError {
    #[location(serde = true)]  // 此变体不添加 #[serde(skip)]
    SpecialError {
        message: String,
    },
}

实现细节

宏的工作流程

  1. 解析输入:解析枚举定义和宏参数
  2. 字段分析:检查每个变体的字段类型和现有属性
  3. 添加 Location 字段:为没有 location 字段的变体添加
  4. 属性处理:添加必要的 snafu 和 serde 属性
  5. 工厂方法生成:基于字段类型生成相应的工厂方法

关键函数

字段类型检测

fn should_add_source_false(field: &syn::Field) -> bool {
    let type_str = field.ty.to_token_stream().to_string();
    let is_option_box_dyn_error = type_str.starts_with("Option < Box < dyn");
    let is_source_field = field.ident.as_ref().map(|name| name == "source").unwrap_or(false);
    is_source_field && is_option_box_dyn_error
}

工厂方法生成

fn generate_factory_methods(input_enum: &ItemEnum) -> darling::Result<TokenStream> {
    // 检测复杂 source 字段
    let has_complex_source = fields_named.named.iter().any(should_add_source_false);
    
    if has_complex_source {
        // 生成两个方法:基础方法和带 source 的方法
    } else {
        // 生成单个方法
    }
}

使用示例

基本使用

#[with_err_location]
#[derive(Debug, Snafu)]
pub enum MyError {
    NetworkError { url: String },
    ValidationError { field: String, message: String },
}

// 使用生成的工厂方法
let error = MyError::network_error("https://api.example.com".to_string());

复杂 Source 字段

#[with_err_location]
#[derive(Debug, Snafu)]
pub enum ComplexError {
    DatabaseError {
        query: String,
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
    },
}

// 两种使用方式
let error1 = ComplexError::database_error("SELECT * FROM users".to_string());
let error2 = ComplexError::database_error_with_source(
    "SELECT * FROM users".to_string(),
    Some(Box::new(io_error))
);

配置选项

#[with_err_location(serde = true)]  // 全局配置
#[derive(Debug, Snafu)]
pub enum ApiError {
    #[location(serde = false)]  // 变体级别覆盖
    InternalError { message: String },
    
    PublicError { message: String },  // 使用全局配置
}

完整代码

#[proc_macro_attribute]
pub fn with_err_location(
    args: proc_macro::TokenStream,
    input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let args = args.into();
    with_err_location::with_err_location_impl(args, input.into())
        .unwrap_or_else(darling::Error::write_errors)
        .into()
} 
use darling::{Error, FromMeta, ast::NestedMeta};
use proc_macro2::TokenStream;
use quote::{ToTokens, quote};
use syn::{Attribute, Field, Fields, ItemEnum, Meta, punctuated::Punctuated, token::Comma};

#[derive(Debug, FromMeta, Default)]
struct WithErrLocationArgs {
    pub serde: bool,
}

pub fn with_err_location_impl(
    args: TokenStream,
    input: TokenStream,
) -> darling::Result<TokenStream> {
    let mut input_enum: ItemEnum = match syn::parse2(input) {
        Ok(v) => v,
        Err(e) => return Err(Error::from(e)),
    };

    // 解析全局参数
    let global_args = if args.is_empty() {
        WithErrLocationArgs::default()
    } else {
        let attr_args = match NestedMeta::parse_meta_list(args) {
            Ok(v) => v,
            Err(e) => return Err(Error::from(e)),
        };
        WithErrLocationArgs::from_list(&attr_args).unwrap_or_default()
    };

    // 遍历枚举的所有变体
    for variant in &mut input_enum.variants {
        // 查找并解析 #[location(...)] 属性
        let (location_config, remaining_attrs) =
            parse_and_remove_location_attrs(&variant.attrs, &global_args)?;

        // 移除 location 属性,保留其他属性
        variant.attrs = remaining_attrs;

        match &mut variant.fields {
            Fields::Named(fields_named) => {
                // 检查是否已经有 location 字段
                let location_field_index = fields_named.named.iter().position(|field| {
                    field
                        .ident
                        .as_ref()
                        .map(|ident| ident == "location")
                        .unwrap_or(false)
                });
                match location_field_index {
                    Some(index) => {
                        // 如果已经有 location 字段,确保它至少有 #[snafu(implicit)]
                        let existing_field = &mut fields_named.named[index];
                        ensure_location_field_has_snafu_implicit(existing_field, &location_config);
                    }
                    None => {
                        // 如果没有 location 字段,则添加一个新的(总是带有 #[snafu(implicit)])
                        let location_field = create_location_field(&location_config);
                        fields_named.named.push(location_field);
                        fields_named.named.push_punct(Comma::default());
                    }
                }
                // 如果有source 且类型是Option<Box<dyn std::error::Error + Send + Sync>>
                // 需要为其加上#[snafu(source(false))]
                for field in &mut fields_named.named {
                    if should_add_source_false(field) {
                        ensure_source_false_attribute(field);
                    }
                }
            }
            _ => {
                return Err(Error::unsupported_format(
                    "Only named fields variants are supported",
                ));
            }
        }
    }

    // 生成工厂方法
    let factory_methods = generate_factory_methods(&input_enum)?;

    Ok(quote! {
        #input_enum
        #factory_methods
    })
}

/// 解析并移除 location 属性,返回配置和剩余属性
fn parse_and_remove_location_attrs(
    variant_attrs: &[Attribute],
    global_args: &WithErrLocationArgs,
) -> darling::Result<(LocationConfig, Vec<Attribute>)> {
    let mut config = LocationConfig {
        serde: global_args.serde,
    };

    let mut remaining_attrs = Vec::new();

    for attr in variant_attrs {
        if attr.path().is_ident("location") {
            // 解析 location 属性的参数
            match &attr.meta {
                Meta::List(meta_list) => {
                    let nested = meta_list.parse_args_with(
                        Punctuated::<NestedMeta, syn::Token![,]>::parse_terminated,
                    )?;
                    let location_args =
                        WithErrLocationArgs::from_list(&nested.into_iter().collect::<Vec<_>>())?;

                    config.serde = location_args.serde;
                }
                _ => {
                    // 如果没有参数,使用默认配置
                }
            }
        } else {
            // 保留非 location 属性
            remaining_attrs.push(attr.clone());
        }
    }

    Ok((config, remaining_attrs))
}

#[derive(Debug)]
struct LocationConfig {
    serde: bool,
}

/// 确保现有的 location 字段至少有 #[snafu(implicit)] 属性
fn ensure_location_field_has_snafu_implicit(field: &mut Field, config: &LocationConfig) {
    // 根据配置添加或确保有 #[serde(skip)]
    if !config.serde {
        let has_serde_skip = field.attrs.iter().any(|attr| {
            if attr.path().is_ident("serde")
                && let Meta::List(meta_list) = &attr.meta
            {
                return meta_list.tokens.to_string().contains("skip");
            }
            false
        });

        if !has_serde_skip {
            let serde_skip_attr: Attribute = syn::parse_quote! {
                #[serde(skip)]
            };
            field.attrs.push(serde_skip_attr);
        }
    }
    let has_snafu_implicit = field.attrs.iter().any(|attr| {
        if attr.path().is_ident("snafu")
            && let Meta::List(meta_list) = &attr.meta
        {
            return meta_list.tokens.to_string().contains("implicit");
        }
        false
    });

    // 如果没有 #[snafu(implicit)],则添加它
    if !has_snafu_implicit {
        let snafu_implicit_attr: Attribute = syn::parse_quote! {
            #[snafu(implicit)]
        };
        field.attrs.push(snafu_implicit_attr);
    }
}

/// 检查字段是否需要自动添加 #[snafu(source(false))]
fn should_add_source_false(field: &syn::Field) -> bool {
    let type_str = field.ty.to_token_stream().to_string();

    // 检查是否是 Option<Box<dyn std::error::Error + Send + Sync>> 类型
    let is_option_box_dyn_error = type_str.starts_with("Option < Box < dyn");

    // 检查字段名是否为 "source"
    let is_source_field = field
        .ident
        .as_ref()
        .map(|name| name == "source")
        .unwrap_or(false);

    is_source_field && is_option_box_dyn_error
}

/// 确保复杂 source 字段有 #[snafu(source(false))] 属性
fn ensure_source_false_attribute(field: &mut Field) {
    // 检查是否已经有 #[snafu(source(false))] 属性
    let has_source_false = field.attrs.iter().any(|attr| {
        if attr.path().is_ident("snafu")
            && let Meta::List(meta_list) = &attr.meta
        {
            let tokens_str = meta_list.tokens.to_string();
            return tokens_str.contains("source")
                && (tokens_str.contains("false") || tokens_str.contains("( false )"));
        }
        false
    });

    // 如果没有,则添加 #[snafu(source(false))]
    if !has_source_false {
        let source_false_attr: Attribute = syn::parse_quote! {
            #[snafu(source(false))]
        };
        field.attrs.push(source_false_attr);
    }
}

/// 根据配置创建 location 字段
fn create_location_field(config: &LocationConfig) -> Field {
    if !config.serde {
        syn::parse_quote! {
            #[serde(skip)]
            #[snafu(implicit)]
            location: snafu::Location
        }
    } else {
        syn::parse_quote! {
            #[snafu(implicit)]
            location: snafu::Location
        }
    }
}

/// 为枚举生成工厂方法
fn generate_factory_methods(input_enum: &ItemEnum) -> darling::Result<TokenStream> {
    let enum_name = &input_enum.ident;
    let mut methods = Vec::new();

    for variant in &input_enum.variants {
        let variant_name = &variant.ident;

        // 将变体名转换为 snake_case
        let method_name = convert_to_snake_case(&variant_name.to_string());
        let method_ident = syn::Ident::new(&method_name, variant_name.span());

        match &variant.fields {
            Fields::Named(fields_named) => {
                // 检查是否有复杂的 source 字段
                let has_complex_source = fields_named.named.iter().any(should_add_source_false);

                if has_complex_source {
                    // 生成两个方法:基础方法(source = None)和带 source 的方法

                    // 1. 基础方法:source 为 None
                    let (base_params, base_assignments) =
                        analyze_fields_for_source_method(fields_named, true);
                    let base_method = quote! {
                        #[track_caller]
                        pub fn #method_ident(#(#base_params),*) -> Self {
                            #enum_name::#variant_name {
                                #(#base_assignments,)*
                            }
                        }
                    };
                    methods.push(base_method);

                    // 2. 带 source 的方法
                    let source_method_name = format!("{}_with_source", method_name);
                    let source_method_ident =
                        syn::Ident::new(&source_method_name, variant_name.span());
                    let (source_params, source_assignments) =
                        analyze_fields_for_source_method(fields_named, false);

                    let source_method = quote! {
                        #[track_caller]
                        pub fn #source_method_ident(#(#source_params),*) -> Self
                        {
                            #enum_name::#variant_name {
                                #(#source_assignments,)*
                            }
                        }
                    };
                    methods.push(source_method);
                } else {
                    // 分析字段,确定需要的参数
                    let (params, field_assignments) = analyze_fields(fields_named);

                    // 生成基础方法
                    let method = quote! {
                        #[track_caller]
                        pub fn #method_ident(#(#params),*) -> Self {
                            #enum_name::#variant_name {
                                #(#field_assignments,)*
                            }
                        }
                    };

                    methods.push(method);
                }
            }
            _ => continue,
        }
    }

    Ok(quote! {
        impl #enum_name {
            #(#methods)*
        }
    })
}

/// 将 PascalCase 转换为 snake_case
fn convert_to_snake_case(s: &str) -> String {
    let mut result = String::new();
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() && i > 0 {
            result.push('_');
        }
        result.push(ch.to_lowercase().next().unwrap());
    }
    result
}

/// 分析字段,生成参数和字段赋值
fn analyze_fields(fields: &syn::FieldsNamed) -> (Vec<TokenStream>, Vec<TokenStream>) {
    let mut params = Vec::new();
    let mut assignments = Vec::new();

    for field in &fields.named {
        let field_name = field.ident.as_ref().unwrap();
        let field_type = &field.ty;

        if field_name == "location" {
            assignments.push(quote! { #field_name: snafu::GenerateImplicitData::generate() });
            continue;
        }

        // 普通字段作为参数
        params.push(quote! { #field_name: #field_type });
        assignments.push(quote! { #field_name });
    }

    (params, assignments)
}

/// 分析字段,为带 source 的方法生成参数和字段赋值
fn analyze_fields_for_source_method(
    fields: &syn::FieldsNamed,
    is_base: bool,
) -> (Vec<TokenStream>, Vec<TokenStream>) {
    let mut params = Vec::new();
    let mut assignments = Vec::new();

    for field in &fields.named {
        let field_name = field.ident.as_ref().unwrap();
        let field_type = &field.ty;

        if field_name == "location" {
            assignments.push(quote! { #field_name: snafu::GenerateImplicitData::generate() });
            continue;
        }

        if is_base && should_add_source_false(field) {
            // 复杂 source 字段设为 None,不作为参数
            assignments.push(quote! { #field_name: None });
        } else {
            // 普通字段作为参数
            params.push(quote! { #field_name: #field_type });
            assignments.push(quote! { #field_name });
        }
    }

    (params, assignments)
}

优势总结

  1. 减少样板代码:自动生成重复的字段和方法定义
  2. 类型安全:在编译时确保正确的类型处理
  3. 灵活配置:支持全局和变体级别的配置选项
  4. 智能处理:自动识别复杂类型并生成相应的方法
  5. 向后兼容:可以与现有的 snafu 代码无缝集成

结论

#[with_err_location] 宏通过自动化错误处理中的重复工作,显著提升了开发效率和代码质量。它不仅减少了样板代码,还通过智能的类型检测和方法生成,提供了更加优雅和类型安全的错误处理解决方案。

无论是简单的错误类型还是复杂的带源错误的场景,这个宏都能提供恰到好处的自动化支持,让开发者能够专注于业务逻辑而不是重复的错误处理代码。

评论区

写评论

还没有评论

1 共 0 条评论, 1 页