< 返回版块

Matheritasiv 发表于 2022-11-25 15:10

Tags:multi-thread

我在使用rust实现一个非递归版本的双调排序算法,以下是单线程版本的代码:

use rand::Rng;
use std::cmp::PartialOrd;

struct SortIndex {
    start: usize,
    end: usize,
    char_num: usize,
    mod_num: usize,
    stop: bool,
}

impl SortIndex {
    fn new(start: usize, end: usize, char_bit: u32, mod_bit: u32) -> Self {
        SortIndex {
            start, end,
            char_num: 1usize.wrapping_shl(char_bit),
            mod_num: 1usize.wrapping_shl(mod_bit).wrapping_sub(1),
            stop: start >= end,
        }
    }
}

impl Iterator for SortIndex {
    type Item = (usize, usize);
    fn next(&mut self) -> Option<(usize, usize)> {
        if self.stop { return None; }
        let out = self.start;
        let mut x = out.wrapping_add(1);
        let _ = (x ^ out) & self.char_num == 0 || x & self.mod_num == 0 || {
            x = x.wrapping_add(self.char_num);
            x & self.mod_num != 0
        } || {
            x = x.wrapping_add(self.char_num);
            true
        };
        if x >= self.end || x <= out {
            self.stop = true;
        } else {
            self.start = x;
        }
        return Some((out, out ^ self.char_num));
    }
}

fn bitonic_sort<T>(data: &mut [T], depth: u32)
where T: Copy + PartialOrd {
    if depth == 0 { return; }
    let n = 1usize << depth;
    for cnt in 1 .. depth + 1 {
        for i in (0 .. cnt).rev() {
            for (ind1, ind2) in SortIndex::new(0, n, i, cnt) {
                if data[ind2] < data[ind1] {
                    data.swap(ind1, ind2);
                }
            }
        }
    }
}

fn main() {
    let mut rng = rand::thread_rng();
    let n = 1000000;
    let mut nums = (0 .. 4096).map(|_| rng.gen_range(-n ..= n)).collect::<Vec<i32>>();
    bitonic_sort(&mut nums, 12);
    println!("{:?}", nums);
}

它的基本原理是由SortIndex迭代器产生元素下标,然后由排序线程对这些下标指向的元素进行比较以及作可能的对换操作。每轮迭代产生的这些下标都是两两不交的,因此每轮对换过程理论上可以并行进行,但用rust线程机制来实现,则需要多个线程同时借用nums的可变引用,这是不被允许的。用chunks_mut()nums切片并分派到各个线程对于这个算法来说也不可行,因为需要对换的下标可能会跨越相当大的范围。请问这个问题有没有什么比较好的、运行开销小的多线程解决方案?

评论区

写评论
作者 Matheritasiv 2022-12-01 16:33

我现在在对这个双调排序的递归版本实现并发,核心代码逻辑类似于这样:

fn bitonic_sort<T>(count: u32, data: &mut [T], rev: bool)
where T: Copy + PartialOrd + Send + 'static {
    if data.len() <= 1 { return; }
    let ind = bitonic_divide(data.len());
    let (data1, data2) = data.split_at_mut(ind);
    if count == 0 {
        bitonic_sort(0, data1, !rev);
        bitonic_sort(0, data2, rev);
    } else {
        let thread1 = spawn(move || bitonic_sort(count - 1, data1, !rev));
        let thread2 = spawn(move || bitonic_sort(count - 1, data2, rev));
        thread1.join().unwrap();
        thread2.join().unwrap();
    }
    bitonic_merge(data, rev);
}

但它会报错,因为编译器无法静态推断出data1data2bitonic_merge()那里生命周期已经结束了。目前我只能给出如下这种unsafe的解决方式:

fn bitonic_sort<T>(count: u32, data: &mut [T], rev: bool)
where T: Copy + PartialOrd + 'static {
    if data.len() <= 1 { return; }
    let len = data.len();
    let ind = bitonic_divide(len);
    if count == 0 {
        let (data1, data2) = data.split_at_mut(ind);
        bitonic_sort(0, data1, !rev);
        bitonic_sort(0, data2, rev);
    } else {
        let data = PtrWrapper::new(data.as_mut_ptr());
        let thread1 = spawn(move || bitonic_sort(count - 1, unsafe {
            std::slice::from_raw_parts_mut(*data, ind)
        }, !rev));
        let thread2 = spawn(move || bitonic_sort(count - 1, unsafe {
            std::slice::from_raw_parts_mut(data.add(ind), len - ind)
        }, rev));
        thread1.join().unwrap();
        thread2.join().unwrap();
    }
    bitonic_merge(data, rev);
}

请问有没有不用unsafe的方式来解决这个生命周期问题?

eweca-d 2022-11-30 20:27

我本来想说rayon里有par_iter,后来想想实现中估计会展开?你的swap操作耗时太短,这样应该是亏的。是我想岔了。

--
👇
Matheritasiv

作者 Matheritasiv 2022-11-30 11:19

这个实现中所有线程共用一个迭代器,从迭代器中读取值会改变内部状态,必须得用读写锁。我猜测是因为在线程的一轮循环里面,比较和交换数组元素用时比迭代器生成下标要短,这样它被挂起时很大概率是持有锁的,直到这个线程下一次被调度排序才会继续进行,整个过程中真正有多个线程在同时排序的时候很少。如果说对*data的读写操作是相对费时的,那这个实现相对于单线程版本应该是有速度提升的。

我改了一下SortIndex的实现,现在这个版本没问题了:

fn bitonic_p_sort<T>(data: &mut [T], t_depth: u32)
where T: Copy + PartialOrd + 'static {
    if data.len() <= 1 { return; }
    let n = data.len();
    let data = PtrWrapper::new(data.as_mut_ptr());
    let t_n = 1usize << t_depth;
    let depth = {
        let (mut depth, mut n) = (1u32, n - 1 >> 1);
        while n != 0 { n >>= 1; depth += 1; }
        depth
    };
    let chunk_depth = if depth > t_depth {
        depth.wrapping_sub(t_depth)
    } else { 0 };
    let chunk = 1usize << chunk_depth;
    let mut rev = depth & 1 == 0;
    for cnt in 1 .. depth + 1 {
        for i in (0 .. cnt).rev() {
            let mut thread_list = vec![];
            for j in 0 .. t_n {
                thread_list.push(spawn(move || {
                    for (ind1, ind2) in SortIndex::new(
                        j * chunk, (j + 1) * chunk, n, cnt, i, rev) {
                        let (ind1, ind2) = (ind1 as isize, ind2 as isize);
                        unsafe {
                            if *data.offset(ind2) < *data.offset(ind1) {
                                (*data.offset(ind1), *data.offset(ind2)) =
                                    (*data.offset(ind2), *data.offset(ind1));
                            }
                        }
                    }
                }));
            }
            for t in thread_list {
                t.join().unwrap();
            }
        }
        rev = !rev;
    }
}

运行时间是单线程版本的一半。

--
👇
eweca-d: 别用Mutex锁,Neutron3529的意思大概是让你直接unsafe直接操作裸指针吧。详见std::ptr模块。

--
👇
Matheritasiv: 大家帮忙看看这个实现有问题吗:

let shared_indexes = Arc::new(Mutex::new(SortIndex::new(0, n, i, cnt)));
for _ in 0 .. t_n {
    let shared_indexes = shared_indexes.clone();
    thread_list.push(spawn(move || {
        while let Some((ind1, ind2)) = shared_indexes.lock().unwrap().next() {
            let (ind1, ind2) = (ind1 as isize, ind2 as isize);
            unsafe {
                if *data.offset(ind2) < *data.offset(ind1) {
                    (*data.offset(ind1), *data.offset(ind2)) =
                        (*data.offset(ind2), *data.offset(ind1));
                }
            }
        }
    }));
}

我实测这个版本的运行时间大约是单线程版本的20倍,不管是开16个线程还是开2个线程。

eweca-d 2022-11-30 09:50

别用Mutex锁,Neutron3529的意思大概是让你直接unsafe直接操作裸指针吧。详见std::ptr模块。

--
👇
Matheritasiv: 大家帮忙看看这个实现有问题吗:

let shared_indexes = Arc::new(Mutex::new(SortIndex::new(0, n, i, cnt)));
for _ in 0 .. t_n {
    let shared_indexes = shared_indexes.clone();
    thread_list.push(spawn(move || {
        while let Some((ind1, ind2)) = shared_indexes.lock().unwrap().next() {
            let (ind1, ind2) = (ind1 as isize, ind2 as isize);
            unsafe {
                if *data.offset(ind2) < *data.offset(ind1) {
                    (*data.offset(ind1), *data.offset(ind2)) =
                        (*data.offset(ind2), *data.offset(ind1));
                }
            }
        }
    }));
}

我实测这个版本的运行时间大约是单线程版本的20倍,不管是开16个线程还是开2个线程。

作者 Matheritasiv 2022-11-29 15:20

大家帮忙看看这个实现有问题吗:

let shared_indexes = Arc::new(Mutex::new(SortIndex::new(0, n, i, cnt)));
for _ in 0 .. t_n {
    let shared_indexes = shared_indexes.clone();
    thread_list.push(spawn(move || {
        while let Some((ind1, ind2)) = shared_indexes.lock().unwrap().next() {
            let (ind1, ind2) = (ind1 as isize, ind2 as isize);
            unsafe {
                if *data.offset(ind2) < *data.offset(ind1) {
                    (*data.offset(ind1), *data.offset(ind2)) =
                        (*data.offset(ind2), *data.offset(ind1));
                }
            }
        }
    }));
}

我实测这个版本的运行时间大约是单线程版本的20倍,不管是开16个线程还是开2个线程。

Neutron3529 2022-11-25 18:16

直接unsafe

unsafe就是让你处理多个可变借用的

你直接unsafe解引用裸指针就好

1 共 6 条评论, 1 页