
Concurrency

Concurrency Model

| Name         | Pros                           | Cons                                        |
| ------------ | ------------------------------ | ------------------------------------------- |
| OS Threads   | simple, native model           | synchronization and context-switch overhead |
| Event Driven | high performance               | non-linear logic, callback hell             |
| Coroutines   | high performance               | not a system-level abstraction              |
| Actor        | distributed model              | complex flow control and retry logic        |
| Async/Await  | high performance, native model | complex internal logic                      |

Use OS threads for CPU-intensive tasks (parallel computing) and async/await for I/O-intensive tasks (blocking I/O).

Threads

use std::thread;
use std::time::Duration;

fn main() {
    let handle = thread::spawn(|| {
        for i in 1..5 {
            println!("hi number {} from the spawned thread!", i);
            thread::sleep(Duration::from_millis(1));
        }
    });

    for i in 1..5 {
        println!("hi number {} from the main thread!", i);
        thread::sleep(Duration::from_millis(1));
    }

    handle.join().unwrap();
}

Barrier

use std::sync::{Arc, Barrier};
use std::thread;

fn main() {
    let mut handles = Vec::with_capacity(6);
    let barrier = Arc::new(Barrier::new(6));

    for _ in 0..6 {
        let b = barrier.clone();
        handles.push(thread::spawn(move || {
            println!("before wait");
            b.wait();
            println!("after wait");
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }
}

Condition Variables and Mutex

use std::thread;
use std::sync::{Arc, Mutex, Condvar};

fn main() {
    let pair = Arc::new((Mutex::new(false), Condvar::new()));
    let pair2 = pair.clone();

    thread::spawn(move || {
        let &(ref lock, ref cvar) = &*pair2;
        let mut started = lock.lock().unwrap();
        println!("changing started");
        *started = true;
        cvar.notify_one();
    });

    let &(ref lock, ref cvar) = &*pair;
    let mut started = lock.lock().unwrap();
    while !*started {
        started = cvar.wait(started).unwrap();
    }

    println!("started changed");
}
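
Sharing a counter between threads with Arc and Mutex:
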
use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    let counter = Arc::new(Mutex::new(0));
    let mut handles = vec![];

    for _ in 0..10 {
        let counter = Arc::clone(&counter);
        let handle = thread::spawn(move || {
            let mut num = counter.lock().unwrap();
            *num += 1;
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("Result: {}", *counter.lock().unwrap());
}

Read-write lock (RwLock):

  • Allows multiple readers at the same time, but at most one writer.
  • Readers and writers cannot exist at the same time.
  • Accessed via read/try_read/write/try_write (a try_write sketch follows the example below).
use std::sync::RwLock;

fn main() {
    let lock = RwLock::new(5);

    // Multiple reads are allowed at the same time.
    {
        let r1 = lock.read().unwrap();
        let r2 = lock.read().unwrap();
        assert_eq!(*r1, 5);
        assert_eq!(*r2, 5);
    } // Read guards dropped here.

    // Only one write is allowed at a time.
    {
        let mut w = lock.write().unwrap();
        *w += 1;
        assert_eq!(*w, 6);
    } // Write guard dropped here.
}
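
A minimal sketch of try_write with the same std::sync::RwLock: it fails fast instead of blocking while a read guard is still alive.

use std::sync::RwLock;

fn main() {
    let lock = RwLock::new(5);

    let r = lock.read().unwrap();
    // A write lock cannot be taken while a reader exists; try_write returns Err.
    assert!(lock.try_write().is_err());
    drop(r);

    // Once the reader is dropped, try_write succeeds.
    let mut w = lock.try_write().unwrap();
    *w += 1;
    assert_eq!(*w, 6);
}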

Thread Communication

Message channel:

use std::sync::mpsc;
use std::thread;

fn main() {
    let (tx, rx) = mpsc::channel();

    thread::spawn(move || {
        tx.send(1).unwrap();
    });

    // recv blocks until a message arrives.
    println!("receive {}", rx.recv().unwrap());
}

Sync channel with message buffer:

use std::sync::mpsc;
use std::thread;
use std::time::Duration;

fn main() {
    // Sync channel with a buffer of 3 messages.
    let (tx, rx) = mpsc::sync_channel(3);

    let handle = thread::spawn(move || {
        println!("before send");
        tx.send(1).unwrap();
        println!("after send");
    });

    println!("before sleep");
    thread::sleep(Duration::from_secs(3));
    println!("after sleep");

    println!("receive {}", rx.recv().unwrap());
    handle.join().unwrap();
}

Sending messages in a loop:

use std::sync::mpsc;
use std::thread;
use std::time::Duration;

fn main() {
    let (tx, rx) = mpsc::channel();

    thread::spawn(move || {
        let values = vec![
            String::from("hi"),
            String::from("from"),
            String::from("the"),
            String::from("thread"),
        ];

        for value in values {
            tx.send(value).unwrap();
            thread::sleep(Duration::from_secs(1));
        }
    });

    for received in rx {
        println!("Got: {}", received);
    }
}

Multiple producers:

use std::sync::mpsc;
use std::thread;

fn main() {
    let (tx, rx) = mpsc::channel();
    let tx1 = tx.clone();

    thread::spawn(move || {
        tx.send(String::from("hi from raw tx")).unwrap();
    });

    thread::spawn(move || {
        tx1.send(String::from("hi from cloned tx")).unwrap();
    });

    for received in rx {
        println!("Got: {}", received);
    }
}

Multiple message types:

use std::sync::mpsc::{self, Receiver, Sender};

enum Fruit {
    Apple(u8),
    Orange(String),
}

fn main() {
    let (tx, rx): (Sender<Fruit>, Receiver<Fruit>) = mpsc::channel();

    tx.send(Fruit::Orange("sweet".to_string())).unwrap();
    tx.send(Fruit::Apple(2)).unwrap();

    for _ in 0..2 {
        match rx.recv().unwrap() {
            Fruit::Apple(count) => println!("received {} apples", count),
            Fruit::Orange(flavor) => println!("received {} oranges", flavor),
        }
    }
}

Tokio Semaphore

use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    let semaphore = Arc::new(Semaphore::new(3));
    let mut join_handles = Vec::new();

    for _ in 0..5 {
        // Only 3 permits exist, so at most 3 tasks run concurrently.
        let permit = semaphore.clone().acquire_owned().await.unwrap();
        join_handles.push(tokio::spawn(async move {
            // Task here ...
            drop(permit);
        }));
    }

    for handle in join_handles {
        handle.await.unwrap();
    }
}

Atomic Primitives

use std::ops::Sub;
use std::sync::atomic::{AtomicU64, Ordering};
use std::thread::{self, JoinHandle};
use std::time::Instant;

const N_TIMES: u64 = 10000000;
const N_THREADS: usize = 10;

static R: AtomicU64 = AtomicU64::new(0);

fn add_n_times(n: u64) -> JoinHandle<()> {
    thread::spawn(move || {
        for _ in 0..n {
            R.fetch_add(1, Ordering::Relaxed);
        }
    })
}

fn main() {
    let s = Instant::now();
    let mut threads = Vec::with_capacity(N_THREADS);

    for _ in 0..N_THREADS {
        threads.push(add_n_times(N_TIMES));
    }

    for thread in threads {
        thread.join().unwrap();
    }

    assert_eq!(N_TIMES * N_THREADS as u64, R.load(Ordering::Relaxed));

    println!("{:?}", Instant::now().sub(s));
}

Memory ordering (Ordering):

  • Relaxed: only atomicity is guaranteed; operations may be reordered freely.
  • Release: a store barrier; operations before it can never be reordered after it.
  • Acquire: a load barrier; operations after it can never be reordered before it.
  • AcqRel: Acquire + Release.
  • SeqCst: sequential consistency.
use std::thread::{self, JoinHandle};
use std::sync::atomic::{Ordering, AtomicBool};

static mut DATA: u64 = 0;
static READY: AtomicBool = AtomicBool::new(false);

fn reset() {
    unsafe {
        DATA = 0;
    }
    READY.store(false, Ordering::Relaxed);
}

fn producer() -> JoinHandle<()> {
    thread::spawn(move || {
        unsafe {
            DATA = 100; // A
        }
        READY.store(true, Ordering::Release); // B: barrier, nothing above moves below this store.
    })
}

fn consumer() -> JoinHandle<()> {
    thread::spawn(move || {
        while !READY.load(Ordering::Acquire) {} // C: barrier, nothing below moves above this load.

        assert_eq!(100, unsafe { DATA }); // D
    })
}

fn main() {
    loop {
        reset();

        let t_producer = producer();
        let t_consumer = consumer();

        t_producer.join().unwrap();
        t_consumer.join().unwrap();
    }
}

Spinlock:

use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{hint, thread};

fn main() {
    let spinlock = Arc::new(AtomicUsize::new(1));
    let spinlock_clone = Arc::clone(&spinlock);
    let thread = thread::spawn(move || {
        spinlock_clone.store(0, Ordering::SeqCst);
    });

    // Wait for the other thread to release the lock.
    while spinlock.load(Ordering::SeqCst) != 0 {
        hint::spin_loop();
    }

    if let Err(panic) = thread.join() {
        println!("Thread had an error: {:?}", panic);
    }
}

Send and Sync Trait

Send and Sync:

  • Both are marker traits (they have no methods).
  • A type that implements Send can have its ownership transferred safely between threads; a type that implements Sync can be shared safely between threads by reference. If &T: Send, then T: Sync.
  • Almost all types implement Send/Sync. Exceptions: raw pointers, Cell/RefCell, and Rc.
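
As an illustration (a minimal sketch, not from the original notes): a raw pointer can be moved into another thread by wrapping it in a newtype, here hypothetically named MyBox, and unsafely asserting Send for the wrapper. The pointer is never dereferenced, so the example stays safe.

use std::thread;

// Raw pointers are neither Send nor Sync; the unsafe impl below asserts that
// moving MyBox across threads is sound (the caller must guarantee this).
#[derive(Debug)]
struct MyBox(*const u8);

unsafe impl Send for MyBox {}

fn main() {
    let p = MyBox(5 as *const u8);

    let handle = thread::spawn(move || {
        // `p` was moved here; without `unsafe impl Send` this would not compile.
        println!("{:?}", p);
    });

    handle.join().unwrap();
}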

Thread Pool

use std::sync::mpsc;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;

type Job = Box<dyn FnOnce() + Send + 'static>;

enum Message {
    NewJob(Job),
    Terminate,
}

pub struct ThreadPool {
    workers: Vec<Worker>,
    sender: mpsc::Sender<Message>,
}

impl ThreadPool {
    /// Create a new ThreadPool.
    ///
    /// The size is the number of threads in the pool.
    ///
    /// # Panics
    ///
    /// The `new` function will panic if the size is zero.
    pub fn new(size: usize) -> ThreadPool {
        assert!(size > 0);

        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        let mut workers = Vec::with_capacity(size);

        for id in 0..size {
            workers.push(Worker::new(id, Arc::clone(&receiver)));
        }

        ThreadPool { workers, sender }
    }

    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender.send(Message::NewJob(job)).unwrap();
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        println!("Sending terminate message to all workers.");

        for _ in &self.workers {
            self.sender.send(Message::Terminate).unwrap();
        }

        println!("Shutting down all workers.");

        for worker in &mut self.workers {
            println!("Shutting down worker {}", worker.id);

            if let Some(thread) = worker.thread.take() {
                thread.join().unwrap();
            }
        }
    }
}

struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Worker {
        let thread = thread::spawn(move || loop {
            let message = receiver.lock().unwrap().recv().unwrap();

            match message {
                Message::NewJob(job) => {
                    println!("Worker {} got a job; executing.", id);
                    job();
                }
                Message::Terminate => {
                    println!("Worker {} was told to terminate.", id);
                    break;
                }
            }
        });

        Worker {
            id,
            thread: Some(thread),
        }
    }
}
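
A minimal usage sketch, assuming the ThreadPool above is in scope (the task bodies are only illustrative):

fn main() {
    // Four worker threads pull jobs from the shared channel.
    let pool = ThreadPool::new(4);

    for i in 0..8 {
        pool.execute(move || {
            println!("task {} executed by the pool", i);
        });
    }

    // Dropping the pool sends Terminate to every worker and joins them.
}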

Async and Await

  • While suspended at an .await, a task may be moved to another thread before it resumes.
  • .await can only be used inside async fn functions (and async blocks).
use futures::executor::block_on;
use futures::join;

struct Song;

async fn learn_song() -> Song { /* ... */ Song }
async fn sing_song(_song: Song) { /* ... */ }
async fn dance() { /* ... */ }

async fn learn_and_sing() {
    let song = learn_song().await;
    sing_song(song).await;
}

async fn async_main() {
    let f1 = learn_and_sing();
    let f2 = dance();
    join!(f1, f2);
}

fn main() {
    block_on(async_main());
}
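
Polling multiple futures with select!:
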
use futures::future;
use futures::select;

pub fn main() {
    let mut a_fut = future::ready(4);
    let mut b_fut = future::ready(6);
    let mut total = 0;

    loop {
        select! {
            a = a_fut => total += a,
            b = b_fut => total += b,
            complete => break,
            default => panic!(), // This branch never runs.
        };
    }

    assert_eq!(total, 10);
}

Future Trait

  • A Future represents a computation and is lazy: it only starts running when it is .awaited and polled.
  • Once running, a Future may block on a resource and enter the pending state.
  • When the resource becomes ready, the Future notifies the executor via Waker::wake so that it gets polled again.
trait Future {
    type Output;
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Self::Output>;
}

use std::{
    future::Future,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll, Waker},
    thread,
    time::Duration,
};

pub struct TimerFuture {
    shared_state: Arc<Mutex<SharedState>>,
}

struct SharedState {
    completed: bool,
    waker: Option<Waker>,
}

impl Future for TimerFuture {
    type Output = ();
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut shared_state = self.shared_state.lock().unwrap();

        if shared_state.completed {
            Poll::Ready(())
        } else {
            // Store the waker so the timer thread can wake this task when done.
            shared_state.waker = Some(cx.waker().clone());
            Poll::Pending
        }
    }
}

impl TimerFuture {
    pub fn new(duration: Duration) -> Self {
        let shared_state = Arc::new(Mutex::new(SharedState {
            completed: false,
            waker: None,
        }));

        let thread_shared_state = shared_state.clone();

        thread::spawn(move || {
            thread::sleep(duration);
            let mut shared_state = thread_shared_state.lock().unwrap();
            shared_state.completed = true;
            if let Some(waker) = shared_state.waker.take() {
                waker.wake()
            }
        });

        TimerFuture { shared_state }
    }
}

Whenever poll returns Poll::Pending, the Future must make sure wake will eventually be called; otherwise the task stays suspended forever and is never polled by the executor again.

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::thread;
use std::time::{Duration, Instant};

struct Delay {
    when: Instant,
}

impl Future for Delay {
    type Output = &'static str;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<&'static str> {
        if Instant::now() >= self.when {
            println!("Hello world");
            Poll::Ready("done")
        } else {
            // Spawn a timer thread that calls wake once the deadline has passed.
            let waker = cx.waker().clone();
            let when = self.when;

            thread::spawn(move || {
                let now = Instant::now();

                if now < when {
                    thread::sleep(when - now);
                }

                waker.wake();
            });

            Poll::Pending
        }
    }
}

#[tokio::main]
async fn main() {
    let when = Instant::now() + Duration::from_millis(10);
    let future = Delay { when };

    let out = future.await;
    assert_eq!(out, "done");
}

Asynchronous Runtime

use {
    futures::{
        future::{BoxFuture, FutureExt},
        task::{waker_ref, ArcWake},
    },
    std::{
        future::Future,
        sync::mpsc::{sync_channel, Receiver, SyncSender},
        sync::{Arc, Mutex},
        task::{Context, Poll},
        time::Duration,
    },
    // The timer future implemented in the previous section.
    timer_future::TimerFuture,
};

struct Task {
    future: Mutex<Option<BoxFuture<'static, ()>>>,
    task_sender: SyncSender<Arc<Task>>,
}

impl ArcWake for Task {
    fn wake_by_ref(arc_self: &Arc<Self>) {
        // Re-queue the task so the executor polls it again.
        let cloned = arc_self.clone();
        arc_self
            .task_sender
            .send(cloned)
            .expect("too many tasks queued");
    }
}

#[derive(Clone)]
struct Spawner {
    task_sender: SyncSender<Arc<Task>>,
}

impl Spawner {
    fn spawn(&self, future: impl Future<Output = ()> + 'static + Send) {
        let future = future.boxed();
        let task = Arc::new(Task {
            future: Mutex::new(Some(future)),
            task_sender: self.task_sender.clone(),
        });
        self.task_sender.send(task).expect("too many tasks queued");
    }
}

struct Executor {
    ready_queue: Receiver<Arc<Task>>,
}

impl Executor {
    fn run(&self) {
        while let Ok(task) = self.ready_queue.recv() {
            let mut future_slot = task.future.lock().unwrap();

            if let Some(mut future) = future_slot.take() {
                let waker = waker_ref(&task);
                let context = &mut Context::from_waker(&*waker);

                if future.as_mut().poll(context).is_pending() {
                    // Not finished yet: put the future back into its task
                    // so it can be polled again after the next wake.
                    *future_slot = Some(future);
                }
            }
        }
    }
}

fn new_executor_and_spawner() -> (Executor, Spawner) {
    const MAX_QUEUED_TASKS: usize = 10_000;
    let (task_sender, ready_queue) = sync_channel(MAX_QUEUED_TASKS);
    (Executor { ready_queue }, Spawner { task_sender })
}

fn main() {
    let (executor, spawner) = new_executor_and_spawner();

    spawner.spawn(async {
        println!("howdy!");
        TimerFuture::new(Duration::new(2, 0)).await;
        println!("done!");
    });

    // Drop the spawner so the executor knows no more tasks will be submitted.
    drop(spawner);

    // Run the executor until the task queue is empty.
    // The task prints `howdy!`, pauses for 2 seconds, then prints `done!`.
    executor.run();
}
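
A mini Tokio-style runtime: tasks are scheduled over a crossbeam channel, and a thread-local sender lets code inside a task spawn further tasks.
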
use std::cell::RefCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::thread;
use std::time::{Duration, Instant};
use futures::task::{self, ArcWake};
use crossbeam::channel;

fn main() {
    let mini_tokio = MiniTokio::new();

    mini_tokio.spawn(async {
        spawn(async {
            delay(Duration::from_millis(100)).await;
            println!("world");
        });

        spawn(async {
            println!("hello");
        });

        delay(Duration::from_millis(200)).await;
        std::process::exit(0);
    });

    mini_tokio.run();
}

struct MiniTokio {
    scheduled: channel::Receiver<Arc<Task>>,
    sender: channel::Sender<Arc<Task>>,
}

impl MiniTokio {
    fn new() -> MiniTokio {
        let (sender, scheduled) = channel::unbounded();

        MiniTokio { scheduled, sender }
    }

    fn spawn<F>(&self, future: F)
    where
        F: Future<Output = ()> + Send + 'static,
    {
        Task::spawn(future, &self.sender);
    }

    fn run(&self) {
        CURRENT.with(|cell| {
            *cell.borrow_mut() = Some(self.sender.clone());
        });

        while let Ok(task) = self.scheduled.recv() {
            task.poll();
        }
    }
}

pub fn spawn<F>(future: F)
where
    F: Future<Output = ()> + Send + 'static,
{
    CURRENT.with(|cell| {
        let borrow = cell.borrow();
        let sender = borrow.as_ref().unwrap();
        Task::spawn(future, sender);
    });
}

async fn delay(dur: Duration) {
    struct Delay {
        when: Instant,
        waker: Option<Arc<Mutex<Waker>>>,
    }

    impl Future for Delay {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if let Some(waker) = &self.waker {
                let mut waker = waker.lock().unwrap();

                if !waker.will_wake(cx.waker()) {
                    *waker = cx.waker().clone();
                }
            } else {
                let when = self.when;
                let waker = Arc::new(Mutex::new(cx.waker().clone()));
                self.waker = Some(waker.clone());

                thread::spawn(move || {
                    let now = Instant::now();

                    if now < when {
                        thread::sleep(when - now);
                    }

                    let waker = waker.lock().unwrap();
                    waker.wake_by_ref();
                });
            }

            if Instant::now() >= self.when {
                Poll::Ready(())
            } else {
                Poll::Pending
            }
        }
    }

    let future = Delay {
        when: Instant::now() + dur,
        waker: None,
    };

    future.await;
}

thread_local! {
    static CURRENT: RefCell<Option<channel::Sender<Arc<Task>>>> =
        RefCell::new(None);
}

struct Task {
    future: Mutex<Pin<Box<dyn Future<Output = ()> + Send>>>,
    executor: channel::Sender<Arc<Task>>,
}

impl Task {
    fn spawn<F>(future: F, sender: &channel::Sender<Arc<Task>>)
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let task = Arc::new(Task {
            future: Mutex::new(Box::pin(future)),
            executor: sender.clone(),
        });

        let _ = sender.send(task);
    }

    fn poll(self: Arc<Self>) {
        let waker = task::waker(self.clone());
        let mut cx = Context::from_waker(&waker);
        let mut future = self.future.try_lock().unwrap();
        let _ = future.as_mut().poll(&mut cx);
    }
}

impl ArcWake for Task {
    fn wake_by_ref(arc_self: &Arc<Self>) {
        let _ = arc_self.executor.send(arc_self.clone());
    }
}
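
Entering the Tokio runtime with the #[tokio::main] macro:
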
#[tokio::main]
async fn main() {
    println!("Hello world");
}
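
The macro is roughly equivalent to building the runtime by hand: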

fn main() {
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            println!("Hello world");
        })
}