
hzqd posted on 2021-03-26 16:48

Rayon already provides a partition method, but it only splits a vector into two parts, and I need three. Here is what I wrote:

use rayon::prelude::*;

trait IterExt<T> {
    fn for_each(self, f: impl Fn(T) + Sync + Send);
    fn partition3(self, predicate1: impl FnMut(&T) -> bool + Sync + Send, predicate2: impl FnMut(&T) -> bool + Sync + Send) -> (Vec<T>, Vec<T>, Vec<T>) where T: Sync;
}

impl<T> IterExt<T> for Vec<T> where T: Send {
    fn for_each(self, f: impl Fn(T) + Sync + Send) {
        self.into_par_iter().for_each(f);
    }

    fn partition3(self, mut predicate1: impl FnMut(&T) -> bool + Sync + Send, mut predicate2: impl FnMut(&T) -> bool + Sync + Send) -> (Vec<T>, Vec<T>, Vec<T>) where T: Sync {
        let mut first = vec![];
        let mut second = vec![];
        let mut third = vec![];
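        // Does not compile: this closure pushes into first/second/third and calls the
        // FnMut predicates, so it needs &mut access, but for_each above only accepts Fn.
        self.for_each(|e|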
            if predicate1(&e) { first.push(e) }
            else if predicate2(&e) { second.push(e) }
            else { third.push(e) }
        );
        (first, second, third)
    }
}

I have been trying for a long time, but I'm still stuck on the borrow checker.

How can I write a correct parallel function that divides a vector into three parts?

Comments

johnmave126 2021-03-27 10:04

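I don't think you can declare predicate1 and predicate2 as FnMut. When Rayon runs the partition in parallel, different threads may need to call predicate1 at the same time. Calling an FnMut requires a mutable (exclusive) reference to it, so concurrent calls would violate the borrowing rules. The same goes for pushing into first, second and third from inside the closure: each push needs &mut access to a Vec that every thread shares.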

Instead, make the predicates Fn, and if they really need to mutate state, use interior mutability with proper synchronization (for example a Mutex or an atomic).
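A std-only sketch of why the Fn + Sync bound is what makes this work. It uses std::thread::scope (Rust 1.63+) instead of Rayon, and count_matches_in_parallel and count_evens_and_calls are hypothetical helpers for illustration only. Several threads call the same predicate through a shared &F, which is only possible for Fn; the predicate still "mutates" by bumping an AtomicUsize, which is interior mutability that needs no &mut.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;

/// Hypothetical helper: counts the elements of `items` that satisfy `pred`,
/// evaluating the predicate on several threads at the same time.
/// Every thread calls `pred` through a shared `&F`, which needs `F: Fn + Sync`;
/// with an `FnMut` bound this would not compile, since each call needs `&mut F`.
fn count_matches_in_parallel<T, F>(items: &[T], pred: F) -> usize
where
    T: Sync,
    F: Fn(&T) -> bool + Sync,
{
    // Split into a handful of chunks, one scoped thread per chunk.
    let chunk_size = (items.len() / 4).max(1);
    let pred = &pred; // shared reference, copied into every thread
    thread::scope(|s| {
        let handles: Vec<_> = items
            .chunks(chunk_size)
            .map(|chunk| s.spawn(move || chunk.iter().filter(|&x| pred(x)).count()))
            .collect();
        handles.into_iter().map(|h| h.join().unwrap()).sum::<usize>()
    })
}

/// Example: an `Fn` predicate that still records state, via an atomic counter.
/// Returns (number of even values, number of predicate calls).
fn count_evens_and_calls(data: &[u32]) -> (usize, usize) {
    let calls = AtomicUsize::new(0);
    let evens = count_matches_in_parallel(data, |x| {
        calls.fetch_add(1, Ordering::Relaxed); // interior mutability, no &mut needed
        x % 2 == 0
    });
    (evens, calls.load(Ordering::Relaxed))
}
```

For the values 1..=100, count_evens_and_calls returns (50, 100): half the values are even, and the predicate ran once per element even though several threads shared it.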

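PS: use partition_map to split into multiple collections. It only produces two outputs (Either::Left and Either::Right), but nesting another Either on the Left side gives you three.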

use rayon::prelude::*;
use rayon::iter::Either;

trait IterExt<T> {
    fn for_each(self, f: impl Fn(T) + Sync + Send);
    fn partition3(self, predicate1: impl Fn(&T) -> bool + Sync + Send, predicate2: impl Fn(&T) -> bool + Sync + Send) -> (Vec<T>, Vec<T>, Vec<T>) where T: Sync;
}

impl<T> IterExt<T> for Vec<T> where T: Send {
    fn for_each(self, f: impl Fn(T) + Sync + Send) {
        self.into_par_iter().for_each(f);
    }

    fn partition3(self, predicate1: impl Fn(&T) -> bool + Sync + Send, predicate2: impl Fn(&T) -> bool + Sync + Send) -> (Vec<T>, Vec<T>, Vec<T>) where T: Sync {
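        // partition_map splits into two collections; nesting Either on the Left
        // side turns the first collection into a (Vec<T>, Vec<T>) pair.
        let ((first, second), third) = self.into_par_iter().partition_map(move |e| {
            if predicate1(&e) {
                Either::Left(Either::Left(e))
            } else if predicate2(&e) {
                Either::Left(Either::Right(e))
            } else {
                Either::Right(e)
            }
        });
        (first, second, third)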
    }
}
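The key idea behind a parallel three-way split without any shared &mut is that each worker fills its own three output buffers, and the buffers are merged afterwards. Below is a hand-rolled, std-only sketch of that idea using std::thread::scope (Rust 1.63+). It is not how Rayon implements partition_map, and partition3_std and its threads parameter are hypothetical names chosen for this illustration.

```rust
use std::thread;

/// Hypothetical std-only three-way partition: returns
/// (elements matching `p1`, elements matching `p2` but not `p1`, the rest).
/// Each thread owns its chunk and its own three Vecs, so nothing is shared mutably.
fn partition3_std<T, P1, P2>(items: Vec<T>, p1: P1, p2: P2, threads: usize) -> (Vec<T>, Vec<T>, Vec<T>)
where
    T: Send,
    P1: Fn(&T) -> bool + Sync,
    P2: Fn(&T) -> bool + Sync,
{
    let threads = threads.max(1);
    let chunk_size = ((items.len() + threads - 1) / threads).max(1);

    // Move the elements into owned chunks so each thread can take ownership.
    let mut iter = items.into_iter();
    let chunks: Vec<Vec<T>> = std::iter::from_fn(|| {
        let chunk: Vec<T> = iter.by_ref().take(chunk_size).collect();
        if chunk.is_empty() { None } else { Some(chunk) }
    })
    .collect();

    // Predicates are shared by reference, which requires Fn + Sync.
    let (p1, p2) = (&p1, &p2);
    let parts: Vec<(Vec<T>, Vec<T>, Vec<T>)> = thread::scope(|s| {
        let handles: Vec<_> = chunks
            .into_iter()
            .map(|chunk| {
                s.spawn(move || {
                    // Thread-local buffers: the pushes never race with other threads.
                    let (mut a, mut b, mut c) = (Vec::new(), Vec::new(), Vec::new());
                    for e in chunk {
                        if p1(&e) { a.push(e) } else if p2(&e) { b.push(e) } else { c.push(e) }
                    }
                    (a, b, c)
                })
            })
            .collect();
        handles.into_iter().map(|h| h.join().unwrap()).collect()
    });

    // Concatenate the per-chunk buffers in chunk order.
    let (mut first, mut second, mut third) = (Vec::new(), Vec::new(), Vec::new());
    for (a, b, c) in parts {
        first.extend(a);
        second.extend(b);
        third.extend(c);
    }
    (first, second, third)
}
```

Because the chunks are concatenated in order, each output keeps the relative order of the input elements. In practice the partition_map version above is shorter and lets Rayon decide how to split the work.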