
eweca-d posted on 2022-09-16 13:03

Python loops are too slow, so I wanted to speed up one function (adaptive peak finding) with Rust. Here is the code:

#[repr(C)]
pub struct VectorI64 {
    size: i64,
    raw_data: *const i64
}

impl VectorI64 {
    fn new(size: usize, raw_data: *const i64) -> Self {
        Self { size: size as i64, raw_data}
    }
}

fn vec_usize_to_vec_i64(data: Vec<usize>) -> Vec<i64> {
    data.iter().map(|&x| x as i64).collect()
}

#[no_mangle]
pub extern "C" fn ampd_rust(size: i64, data: *mut f64) -> *const VectorI64 {
    unsafe {
        let data = Vec::from_raw_parts(data, size as usize, size as usize);
        let pks = find_peaks::ampd(data);
        let pks = vec_usize_to_vec_i64(pks);
        println!("pks(i64) has been found, len is {}, content is {:?}", pks.len(), pks);
        let vec_i64 = Box::new(VectorI64::new(pks.len(), pks.as_ptr()));
        Box::leak(vec_i64) as *const VectorI64
    }
}

The Python side:

from ctypes import POINTER, Structure, c_int64, cdll

import numpy as np

vector_i64_list = [("size", c_int64),
                   ("raw_data", POINTER(c_int64))]


class VectorI64(Structure):
    _fields_ = vector_i64_list


algorithms = cdll.LoadLibrary('algorithms.dll')


def AMPD_rust(data):
    data_size = len(data)

    ampd = algorithms.ampd_rust
    ampd.argtypes = [c_int64,
                     np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags="C_CONTIGUOUS")]
    ampd.restype = POINTER(VectorI64)
    res = ampd(data_size, data).contents
    ret = []
    print("the size of res vector is {}".format(res.size))
    for i in range(res.size):
        ret.append(res.raw_data[i])
    print("return vector is {}".format(ret))
    return ret

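The function takes a one-dimensional NumPy array and returns an array holding the indices of the peaks. The strange part is that the result array printed inside Rust looks perfectly normal (usize is cast straight to i64, but in this case the values are guaranteed not to overflow). After Python reads it back and rebuilds the list, however, most of the values are fine (correct value, correct position), while some have turned into things like "2224283632912, 2224283206816, 17740747716985036, 2224283632912, 2224283206816". The first two indices are almost always corrupted, and the later ones are corrupted in a pattern (three consecutive values change, then the values are normal, then three consecutive values change again). I'm stumped. Printing "res.raw_data.contents" in Python also shows the first value as something like "c_longlong(2224283632912)".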

What causes this? Any help is appreciated. Thanks!


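Update: the answer is that the pks Vec has to be forgotten (mem::forget), otherwise its buffer is freed when the function returns and raw_data is left dangling. A new problem, though, is that the "Vec::from_raw_parts" step is also wrong, because the "*mut f64" buffer is allocated by NumPy on the Python side. How to get a Vec without Rust freeing the *mut f64 still needs more searching. My current idea is to simply forget the data Vec: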

#[no_mangle]
pub extern "C" fn ampd_rust(size: i64, capacity: i64, data: *mut f64) -> *mut VectorI64 {
    unsafe {
        let data = Vec::from_raw_parts(data, size as usize, capacity as usize);
        let pks = find_peaks::ampd(&data);
        mem::forget(data);  // forget data: Vec<f64> so Rust never frees NumPy's buffer (avoids a double free)
        let mut pks = vec_usize_to_vec_i64(pks);
        let vec_i64 = Box::new(VectorI64::new(pks.len(), pks.capacity(), pks.as_mut_ptr()));
        mem::forget(pks);
        Box::leak(vec_i64) as *mut VectorI64
    }
}

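But I'm not sure whether this leaks memory. For example, the *mut f64 buffer should be freed by NumPy, but the size and capacity that the Rust Vec carries presumably won't be freed? Perhaps into_raw_parts, if it could be used on stable, would avoid this problem?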
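
On the size/capacity worry: as far as I understand it, a Vec's pointer, length, and capacity are three plain fields stored inline in the Vec value itself, so forgetting the Vec only skips the destructor that would free the heap buffer. The fields go away with the stack frame and nothing extra leaks. A similar split into raw parts can be done on stable with ManuallyDrop. The sketch below is only an illustration; `into_parts` and `from_parts` are hypothetical helper names, not standard library functions.

```rust
use std::mem::ManuallyDrop;

/// Split a Vec into (pointer, length, capacity) without freeing its buffer.
/// `into_parts` is a hypothetical helper name, not a standard library API.
pub fn into_parts(v: Vec<i64>) -> (*mut i64, usize, usize) {
    // ManuallyDrop suppresses the destructor, so the heap buffer stays alive
    // after this function returns. The three fields are simply copied out.
    let mut v = ManuallyDrop::new(v);
    (v.as_mut_ptr(), v.len(), v.capacity())
}

/// Rebuild the Vec so that Rust's allocator frees the buffer when it drops.
///
/// # Safety
/// The parts must come from `into_parts` and must be reassembled only once.
pub unsafe fn from_parts(ptr: *mut i64, len: usize, cap: usize) -> Vec<i64> {
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
```

This only applies to buffers that Rust allocated. The NumPy buffer must never be turned into a Vec this way, since Rust would eventually hand it to the wrong allocator.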


I went with what may be a better approach, using a slice instead:

#[no_mangle]
pub extern "C" fn ampd_rust(size: i64, data: *mut f64) -> *mut VectorI64 {
    unsafe {
        let data_slice = std::slice::from_raw_parts(data, size as usize);
        let pks = find_peaks::ampd(data_slice);
        let mut pks = vec_usize_to_vec_i64(pks);
        pks.shrink_to_fit();
        let vec_i64 = Box::new(VectorI64::new(pks.len(), pks.capacity(), pks.as_mut_ptr()));
        mem::forget(pks);
        Box::leak(vec_i64) as *mut VectorI64
    }
}

Now the input array should be free of the forget, size, and capacity headaches.
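
One thing the versions above leave open is who frees the returned VectorI64: both the Box and the pks buffer are deliberately leaked, so Python would have to call back into Rust to release them. The following is a minimal sketch of such a pair, under stated assumptions: the `capacity` field is inferred from the three-argument `VectorI64::new` call (the updated struct is not shown in this post), and `vector_i64_from_vec` and `vector_i64_free` are hypothetical names.

```rust
use std::mem::ManuallyDrop;

/// Assumed layout: the three-argument `VectorI64::new` call suggests that a
/// `capacity` field was added, but the updated struct is not shown in the post.
#[repr(C)]
pub struct VectorI64 {
    size: i64,
    capacity: i64,
    raw_data: *mut i64,
}

/// Hypothetical helper: hand a Vec<i64> to the caller as a heap-allocated header.
pub fn vector_i64_from_vec(pks: Vec<i64>) -> *mut VectorI64 {
    // Keep the buffer alive after this function returns.
    let mut pks = ManuallyDrop::new(pks);
    Box::into_raw(Box::new(VectorI64 {
        size: pks.len() as i64,
        capacity: pks.capacity() as i64,
        raw_data: pks.as_mut_ptr(),
    }))
}

/// Hypothetical free function for Python to call after copying the results.
/// When exporting it from the DLL, add the same attribute used on `ampd_rust`.
///
/// # Safety
/// `ptr` must be null or come from `vector_i64_from_vec`, and be freed at most once.
pub unsafe extern "C" fn vector_i64_free(ptr: *mut VectorI64) {
    if ptr.is_null() {
        return;
    }
    unsafe {
        // Take back ownership of the header, then of the buffer it points to.
        let header = Box::from_raw(ptr);
        drop(Vec::from_raw_parts(
            header.raw_data,
            header.size as usize,
            header.capacity as usize,
        ));
    }
}
```

On the Python side this would presumably mean keeping the POINTER(VectorI64) returned by ampd_rust, copying the values into a list, and then passing that same pointer to the free function.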

Comments

Author eweca-d 2022-09-16 14:03

Especially for computations with nested loops, Rust really is the way to go.

rust: n = 1000, elapsed time is 0.0012737000000000442s

python: n = 1000, elapsed time is 0.07528889999999999s

rust: n = 5000, elapsed time is 0.02186749999999993s

python: n = 5000, elapsed time is 1.6737611000000001s

rust: n = 10000, elapsed time is 0.08178969999999985s

python: n = 10000, elapsed time is 6.7801179s

Author eweca-d 2022-09-16 13:56

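Haha, thanks for the inspiration. It isn't Pin, though, since this allocation is on the heap. Seeing Pin suddenly reminded me that the memory might have been freed and then overwritten by something else, and then I found that I had forgotten to forget the Vec.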

The fix is as follows:

#[no_mangle]
pub extern "C" fn ampd_rust(size: i64, data: *mut f64) -> *const VectorI64 {
    unsafe {
        let data = Vec::from_raw_parts(data, size as usize, size as usize);
        let pks = find_peaks::ampd(data);
        let pks = vec_usize_to_vec_i64(pks);
        println!("pks(i64) has been found, len is {}, content is {:?}", pks.len(), pks);
        let vec_i64 = Box::new(VectorI64::new(pks.len(), pks.as_ptr()));
        mem::forget(pks);
        Box::leak(vec_i64) as *const VectorI64
    }
}

It works now.

--
👇
yinheli: Not sure, does it need Pin?

yinheli 2022-09-16 13:33

Not sure, does it need Pin?
