< 返回版块

allwefantasy 发表于 2019-04-21 11:50

Tags:Rust,C,FFI,Struct,Pointer

Suppose we have to structs:

#[repr(C)]
pub struct CTensor {
    data: *const c_float,
    data_length: c_int,
    shape: *const c_int,
    shape_length: c_int,
}

#[repr(C)]
pub struct CTensorArray {
    data: *const CTensor,
    len: c_int,
}

and we provide two methods to create them:

#[no_mangle]
pub extern "C" fn create_tensor(data: *const c_float,
                                data_length: c_int,
                                shape: *const c_int,
                                shape_length: c_int, ) -> *mut CTensor {
    let ctensor = CTensor {
        data,
        data_length,
        shape,
        shape_length,
    };

    println!("create_tensor data: {:?} len:{:?}", unsafe { *ctensor.data }, ctensor.data_length);
    println!("create_tensor shape: {:?} len:{:?}", unsafe { *ctensor.shape }, ctensor.shape_length);

    Box::into_raw(Box::new(ctensor))
}

#[no_mangle]
pub extern "C" fn create_tensor_array(data: *const CTensor,
                                      len: c_int) -> *mut CTensorArray {
    assert!(!data.is_null());
    let tensor_array = CTensorArray {
        data,
        len,
    };
    Box::into_raw(Box::new(tensor_array))
}

Here are the corresponding c struct and function signatures:

typedef struct CTensor {
    const float *data;
    int32_t data_length;
    const int *shape;
    int32_t shape_length;
} CTensor;

typedef struct CTensorArray {
    CTensor *data;
    int32_t len;
} CTensorArray;

CTensor *create_tensor(float *data, int32_t data_length, int32_t *shape, int32_t shape_length);

CTensorArray *create_tensor_array(CTensor *data, int32_t len);

The code in C side are like this:

CTensor *xTensor = create_tensor(xP, 1, shape_x_p, 1);
CTensor *yTensor = create_tensor(yP, 1, shape_y_p, 1);

CTensor *xy[] = {xTensor, yTensor};

CTensor *xy_p;
xy_p = xy;

CTensorArray *tarray = create_tensor_array(xy_p, 2);

//finally the tarray will be passed to predict
RawTensor *wow = predict(pre, "y_hat", "x,y", tarray);

here are the predict code:

#[no_mangle]
pub extern "C" fn predict(predictor: *mut Predictor, output_name: FfiStr, input_names: FfiStr, input_values: *mut CTensorArray) -> *mut Tensor<f32> {
    let r_predictor = unsafe {
        assert!(!predictor.is_null());
        *Box::from_raw(predictor)
    };

    let r_input_names = input_names.as_str().split(",").collect::<Vec<&str>>();

    assert!(!input_values.is_null());

    let input_values_ref = unsafe {
        *Box::from_raw(input_values)
    };

    let input_values_c_tensors = unsafe {
        slice::from_raw_parts(input_values_ref.data, input_values_ref.len as usize)
    };

    let mut r_input_values = Vec::new();

    for item in input_values_c_tensors.iter() {
        // here will panic
        r_input_values.push(CTensor::to(item))
    }

    let mut r_input_values_with_ref = Vec::new();
    for item in r_input_values.iter_mut() {
        r_input_values_with_ref.push(item)
    }


    let r_output_name = output_name.as_str();
    let tensor = r_predictor.predict(r_output_name, r_input_names, r_input_values_with_ref);

    Box::into_raw(Box::new(tensor))
}

CTensor::to code:

 fn to(ctensor: &CTensor) -> Tensor<f32> {
        println!("to_ctensor data: {:?} len:{:?}", unsafe { *ctensor.data }, ctensor.data_length);
        println!("to_ctensor shape: {:?} len:{:?}", unsafe { *ctensor.shape }, ctensor.shape_length);

        let mut shape_vec = Vec::new();

        let shape_slice = unsafe { slice::from_raw_parts(ctensor.shape, ctensor.shape_length as usize) };


        for &item in shape_slice.iter() {
            shape_vec.push(item as u64)
        }


        let shape = shape_vec.as_slice();

        let data = unsafe {
            slice::from_raw_parts(ctensor.data, ctensor.data_length as usize)
        };

        println!("shape_slice: {:?} shape: {:?} data: {:?}", shape_slice, shape_vec, data);

        let tensor = Tensor::new(shape).with_values(data).unwrap();
        tensor
    }

The question is, we can not restore Ctensor from CTensorArray again. Are there i'am missing?

评论区

写评论

还没有评论

1 共 0 条评论, 1 页