I am trying to learn about rayon's thread pool. To that end, I want to make a simple program where a set of tasks is added to the pool, and when the last task in that batch executes, it adds more tasks.
This is what I came up with:
use rayon;
use std::sync::atomic::AtomicU32;
use std::sync::{Arc, Mutex};

fn main() {
    let thread_num = std::thread::available_parallelism().unwrap().into();
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(thread_num)
        .exit_handler(|i| {
            println!("Thread {} exited", i);
        })
        .build()
        .unwrap();
    let pool = Arc::new(pool);
    let counter = Arc::new(AtomicU32::new(thread_num as u32));
    let test = Arc::new(Mutex::new(String::default()));
    for i in 0..thread_num {
        let _pool = pool.clone();
        let _counter = counter.clone();
        let _test = test.clone();
        pool.spawn(move || {
            let val = _counter.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
            _test
                .lock()
                .unwrap()
                .push_str(format!("b {}\n", val).as_str());
            if val == 1 {
                _counter.store(thread_num as u32, std::sync::atomic::Ordering::SeqCst);
                for i in 0..thread_num {
                    let _test = _test.clone();
                    _pool.spawn(move || {
                        _test.lock().unwrap().push_str("aaaa\n");
                    });
                }
            }
        });
    }
    println!("{}", test.lock().unwrap().as_str());
}
Which is printing things like this:
b 12
b 11
b 10
b 9
b 8
b 7
The actual output changes between invocations, sometimes "aaaa" are present, sometimes not.
So A) I am printing before all the tasks finish, somehow, and B) it doesn't seem like the thread exit callback is getting called.
What am I doing wrong?
Simply using rayon::spawn will not wait for the tasks to finish; when your main ends, any tasks that are still pending are simply discarded.
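For illustration: if you wanted to stick with spawn and still block until the spawned tasks are done, you would have to signal completion yourself, for example over a channel. This is only a minimal sketch of that idea, with placeholder task bodies, and not the approach recommended below:

use rayon;
use std::sync::mpsc;

fn main() {
    let pool = rayon::ThreadPoolBuilder::new().build().unwrap();
    // One sender per task; the receiver is used purely to count completions.
    let (tx, rx) = mpsc::channel();
    for i in 0..4 {
        let tx = tx.clone();
        pool.spawn(move || {
            println!("task {}", i);
            // spawn() gives no handle to wait on, so signal completion manually.
            tx.send(()).unwrap();
        });
    }
    // Drop the original sender so recv() errors out once every task has reported in.
    drop(tx);
    while rx.recv().is_ok() {}
    println!("all spawned tasks finished");
}

Having to thread a channel through every task is exactly the bookkeeping that rayon::scope removes.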
If you want fork-join parallelism, you probably want to use rayon::scope instead: it runs a closure and, before returning, waits for all tasks spawned inside it. This has the side effect that you no longer need to mess around with Arc; because rayon knows the tasks will have finished, you can use normal references instead.
Like so:
use rayon;
use std::sync::Mutex;
use std::sync::atomic::AtomicU32;

fn main() {
    let thread_num = std::thread::available_parallelism().unwrap().into();
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(thread_num)
        .exit_handler(|i| {
            println!("Thread {} exited", i);
        })
        .build()
        .unwrap();

    let test = Mutex::new(String::default());
    let counter = AtomicU32::new(thread_num as u32);

    pool.scope(|scope| {
        for _ in 0..thread_num {
            scope.spawn(|scope| {
                let val = counter.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
                test.lock()
                    .unwrap()
                    .push_str(format!("b {}\n", val).as_str());
                if val == 1 {
                    counter.store(thread_num as u32, std::sync::atomic::Ordering::SeqCst);
                    for _ in 0..thread_num {
                        scope.spawn(|_| {
                            test.lock().unwrap().push_str("aaaa\n");
                        });
                    }
                }
            });
        }
    });

    println!("{}", test.lock().unwrap().as_str());
}
b 6
b 2
b 1
aaaa
b 4
aaaa
b 3
aaaa
aaaa
aaaa
b 5
aaaa
Thread 3 exited
Thread 1 exited
Thread 5 exited
Be aware that while this does wait for the spawned tasks to run, it does not wait until all threads have executed their exit handlers, so you do not necessarily see all Thread {} exited messages.
If you want the program to wait for the exit handlers as well, you might want to call build_scoped instead of build and then perform your work inside of the scope:
use rayon;
use std::sync::Mutex;
use std::sync::atomic::AtomicU32;

fn main() {
    let thread_num = std::thread::available_parallelism().unwrap().into();
    rayon::ThreadPoolBuilder::new()
        .num_threads(thread_num)
        .exit_handler(|i| {
            println!("Thread {} exited", i);
        })
        .build_scoped(
            |thread| thread.run(),
            |pool| {
                let test = Mutex::new(String::default());
                let counter = AtomicU32::new(thread_num as u32);
                pool.scope(|scope| {
                    for _ in 0..thread_num {
                        scope.spawn(|scope| {
                            let val = counter.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
                            test.lock()
                                .unwrap()
                                .push_str(format!("b {}\n", val).as_str());
                            if val == 1 {
                                counter
                                    .store(thread_num as u32, std::sync::atomic::Ordering::SeqCst);
                                for _ in 0..thread_num {
                                    scope.spawn(|_| {
                                        test.lock().unwrap().push_str("aaaa\n");
                                    });
                                }
                            }
                        });
                    }
                });
                println!("{}", test.lock().unwrap().as_str());
            },
        )
        .unwrap();
}
b 6
b 3
b 5
b 1
aaaa
aaaa
b 4
b 2
aaaa
aaaa
aaaa
aaaa
Thread 0 exited
Thread 3 exited
Thread 2 exited
Thread 1 exited
Thread 4 exited
Thread 5 exited
That said, your problem description strongly suggests that you are trying to schedule your work manually. I would advise against that; instead, I strongly recommend using rayon's parallel iterators. Efficient work scheduling is hard, and a lot of effort has already gone into that problem in the parallel iterators.
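As a rough sketch of what the "one batch, then a follow-up batch" shape could look like with parallel iterators (the batch contents here are placeholders, not your real work):

use rayon::prelude::*;

fn main() {
    let thread_num: usize = std::thread::available_parallelism().unwrap().into();
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(thread_num)
        .build()
        .unwrap();

    // install() runs the closure inside the pool and blocks until it returns,
    // so everything here is finished before we print.
    let (first, second): (Vec<String>, Vec<String>) = pool.install(|| {
        // First batch: rayon decides how to split the range across threads.
        let first: Vec<String> = (0..thread_num)
            .into_par_iter()
            .map(|i| format!("b {}", i))
            .collect();
        // Second batch, started only after the first has completed.
        let second: Vec<String> = (0..thread_num)
            .into_par_iter()
            .map(|_| "aaaa".to_string())
            .collect();
        (first, second)
    });

    println!("{}", first.join("\n"));
    println!("{}", second.join("\n"));
}

Each map closure returns a value instead of pushing into a shared Mutex<String>, so there is no lock contention and nothing can be printed before the work is done.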
If you add a sleep(Duration::from_secs(1)) before the final print, it at least prints what you expect. – Finomnis Commented Mar 16 at 20:18
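For completeness, the workaround the comment hints at boils down to something like this sketch (the task body is a placeholder); it only papers over the race by sleeping and is not a substitute for scope or build_scoped:

use rayon;
use std::sync::{Arc, Mutex};
use std::thread::sleep;
use std::time::Duration;

fn main() {
    let pool = rayon::ThreadPoolBuilder::new().build().unwrap();
    let test = Arc::new(Mutex::new(String::new()));
    for i in 0..4 {
        let test = test.clone();
        pool.spawn(move || {
            test.lock().unwrap().push_str(&format!("b {}\n", i));
        });
    }
    // Crude: give the detached tasks a second to run before reading `test`.
    // This is still a race, just one you are unlikely to lose.
    sleep(Duration::from_secs(1));
    println!("{}", test.lock().unwrap().as_str());
}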