Skip to content

Rust 多线程编程完全指南

📚 一、Rust 多线程核心概念

1. Rust 的线程安全保证

Rust 通过所有权系统类型系统编译时防止数据竞争,这是 Rust 最强大的特性之一。

// Rust 保证:如果代码编译通过,就没有数据竞争
fn main() {
    // 这是线程安全的基石
}

🎯 二、创建和管理线程

1. 基本线程创建

use std::thread;
use std::time::Duration;

fn basic_threads() {
    println!("主线程开始");

    // 创建新线程
    let handle = thread::spawn(|| {
        for i in 1..=5 {
            println!("子线程: 计数 {}", i);
            thread::sleep(Duration::from_millis(500));
        }
    });

    // 主线程继续执行
    for i in 1..=3 {
        println!("主线程: 工作 {}", i);
        thread::sleep(Duration::from_millis(300));
    }

    // 等待子线程完成
    handle.join().unwrap();
    println!("所有线程完成");
}

2. 使用 move 闭包获取所有权

fn move_closure_threads() {
    let data = vec![1, 2, 3, 4, 5];

    // 使用 move 将 data 的所有权转移到线程中
    let handle = thread::spawn(move || {
        println!("线程中的数据: {:?}", data);
        // 这里可以安全地使用 data
    });

    // 这里不能再使用 data,因为所有权已转移
    // println!("{:?}", data);  // ❌ 编译错误

    handle.join().unwrap();
}

3. 线程句柄和 Join

fn thread_handles() {
    let mut handles = vec![];

    // 创建多个线程
    for i in 0..5 {
        let handle = thread::spawn(move || {
            println!("线程 {} 开始", i);
            thread::sleep(Duration::from_millis(i * 100));
            println!("线程 {} 结束", i);
            i * 2  // 线程返回值
        });
        handles.push(handle);
    }

    // 收集所有线程的结果
    let results: Vec<i32> = handles
        .into_iter()
        .filter_map(|h| h.join().ok())
        .collect();

    println!("所有线程结果: {:?}", results);
}

🔄 三、线程间通信

1. 通道(Channel)基础

use std::sync::mpsc;  // 多生产者,单消费者

fn basic_channels() {
    // 创建通道
    let (tx, rx) = mpsc::channel();

    // 创建发送线程
    let sender = thread::spawn(move || {
        for i in 0..5 {
            tx.send(i).unwrap();
            println!("发送: {}", i);
            thread::sleep(Duration::from_millis(100));
        }
    });

    // 在主线程接收
    for received in rx {
        println!("接收: {}", received);
    }

    sender.join().unwrap();
}

2. 多生产者

fn multiple_producers() {
    let (tx, rx) = mpsc::channel();

    // 创建多个发送者
    let tx1 = tx.clone();
    let tx2 = tx.clone();

    // 线程1
    let producer1 = thread::spawn(move || {
        for i in 0..3 {
            tx1.send(format!("线程1-消息{}", i)).unwrap();
            thread::sleep(Duration::from_millis(50));
        }
    });

    // 线程2
    let producer2 = thread::spawn(move || {
        for i in 0..3 {
            tx2.send(format!("线程2-消息{}", i)).unwrap();
            thread::sleep(Duration::from_millis(30));
        }
    });

    // 在主线程接收所有消息
    drop(tx);  // 关闭原始发送者

    for msg in rx {
        println!("收到: {}", msg);
    }

    producer1.join().unwrap();
    producer2.join().unwrap();
}

3. 同步通道和异步通道

fn sync_vs_async_channels() {
    use std::sync::mpsc;

    // 异步通道(无界缓冲区)
    let (async_tx, async_rx) = mpsc::channel();

    // 同步通道(有界缓冲区)
    let (sync_tx, sync_rx) = mpsc::sync_channel(3);  // 缓冲区大小为3

    // 异步发送不会阻塞
    let async_sender = thread::spawn(move || {
        for i in 0..10 {
            async_tx.send(i).unwrap();
            println!("异步发送: {}", i);
        }
    });

    // 同步发送在缓冲区满时会阻塞
    let sync_sender = thread::spawn(move || {
        for i in 0..10 {
            sync_tx.send(i).unwrap();
            println!("同步发送: {}", i);
            thread::sleep(Duration::from_millis(50));
        }
    });

    // 接收
    let receiver = thread::spawn(move || {
        thread::sleep(Duration::from_millis(100));
        for msg in async_rx {
            println!("异步接收: {}", msg);
        }

        for msg in sync_rx {
            println!("同步接收: {}", msg);
        }
    });

    async_sender.join().unwrap();
    sync_sender.join().unwrap();
    receiver.join().unwrap();
}

🔐 四、共享状态并发

1. 互斥锁(Mutex)

use std::sync::{Arc, Mutex};

fn mutex_example() {
    // 使用 Arc 实现线程间共享所有权
    let counter = Arc::new(Mutex::new(0));
    let mut handles = vec![];

    for _ in 0..10 {
        let counter = Arc::clone(&counter);
        let handle = thread::spawn(move || {
            // 获取锁
            let mut num = counter.lock().unwrap();
            *num += 1;
            // 锁在离开作用域时自动释放
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("最终计数: {}", *counter.lock().unwrap());
}

2. 读写锁(RwLock)

use std::sync::{Arc, RwLock};
use std::time::Duration;

fn rwlock_example() {
    let data = Arc::new(RwLock::new(vec![1, 2, 3]));
    let mut handles = vec![];

    // 多个读线程
    for i in 0..5 {
        let data = Arc::clone(&data);
        let handle = thread::spawn(move || {
            // 获取读锁
            let reader = data.read().unwrap();
            println!("读线程{}: {:?}", i, *reader);
            thread::sleep(Duration::from_millis(10));
        });
        handles.push(handle);
    }

    // 写线程
    let data_write = Arc::clone(&data);
    let write_handle = thread::spawn(move || {
        // 获取写锁
        let mut writer = data_write.write().unwrap();
        writer.push(4);
        writer.push(5);
        println!("写线程: 添加了新元素");
    });
    handles.push(write_handle);

    for handle in handles {
        handle.join().unwrap();
    }

    println!("最终数据: {:?}", *data.read().unwrap());
}

3. 条件变量(Condvar)

use std::sync::{Arc, Mutex, Condvar};

fn condvar_example() {
    let pair = Arc::new((Mutex::new(false), Condvar::new()));
    let pair2 = Arc::clone(&pair);

    // 等待线程
    let waiter = thread::spawn(move || {
        let (lock, cvar) = &*pair2;
        let mut started = lock.lock().unwrap();

        println!("等待线程: 等待条件...");
        while !*started {
            started = cvar.wait(started).unwrap();
        }

        println!("等待线程: 条件满足!");
    });

    // 通知线程
    thread::sleep(Duration::from_secs(1));

    let (lock, cvar) = &*pair;
    let mut started = lock.lock().unwrap();
    *started = true;
    println!("通知线程: 通知所有等待者");
    cvar.notify_all();

    waiter.join().unwrap();
}

⚡ 五、原子操作

1. 原子类型

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

fn atomic_example() {
    let counter = Arc::new(AtomicUsize::new(0));
    let mut handles = vec![];

    for _ in 0..10 {
        let counter = Arc::clone(&counter);
        let handle = thread::spawn(move || {
            for _ in 0..1000 {
                // 原子增加,不需要锁
                counter.fetch_add(1, Ordering::SeqCst);
            }
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("最终计数: {}", counter.load(Ordering::SeqCst));
}

2. 内存排序

fn memory_ordering() {
    use std::sync::atomic::{AtomicBool, AtomicUsize};
    use std::sync::Arc;

    let flag = Arc::new(AtomicBool::new(false));
    let data = Arc::new(AtomicUsize::new(0));

    let flag_clone = Arc::clone(&flag);
    let data_clone = Arc::clone(&data);

    // 生产者线程
    let producer = thread::spawn(move || {
        data_clone.store(42, Ordering::Relaxed);
        flag_clone.store(true, Ordering::Release);  // 释放语义
    });

    // 消费者线程
    let consumer = thread::spawn(move || {
        while !flag.load(Ordering::Acquire) {  // 获取语义
            // 忙等待
        }
        let value = data.load(Ordering::Relaxed);
        println!("读取的值: {}", value);
    });

    producer.join().unwrap();
    consumer.join().unwrap();
}

🧵 六、线程池

1. 简单线程池实现

use std::sync::{Arc, Mutex, mpsc};
use std::thread;

type Job = Box<dyn FnOnce() + Send + 'static>;

struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
        let thread = thread::spawn(move || loop {
            let message = receiver.lock().unwrap().recv();

            match message {
                Ok(job) => {
                    println!("Worker {} 开始执行任务", id);
                    job();
                }
                Err(_) => {
                    println!("Worker {} 关闭", id);
                    break;
                }
            }
        });

        Worker {
            id,
            thread: Some(thread),
        }
    }
}

pub struct ThreadPool {
    workers: Vec<Worker>,
    sender: Option<mpsc::Sender<Job>>,
}

impl ThreadPool {
    pub fn new(size: usize) -> ThreadPool {
        assert!(size > 0);

        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));

        let mut workers = Vec::with_capacity(size);

        for id in 0..size {
            workers.push(Worker::new(id, Arc::clone(&receiver)));
        }

        ThreadPool {
            workers,
            sender: Some(sender),
        }
    }

    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender.as_ref().unwrap().send(job).unwrap();
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        drop(self.sender.take());

        for worker in &mut self.workers {
            println!("关闭 worker {}", worker.id);

            if let Some(thread) = worker.thread.take() {
                thread.join().unwrap();
            }
        }
    }
}

fn thread_pool_example() {
    let pool = ThreadPool::new(4);

    for i in 0..8 {
        pool.execute(move || {
            println!("执行任务 {}", i);
            thread::sleep(Duration::from_millis(500));
        });
    }

    thread::sleep(Duration::from_secs(3));
}

🔄 七、异步编程基础

1. 使用 Tokio

// Cargo.toml: tokio = { version = "1.0", features = ["full"] }

#[tokio::main]
async fn tokio_example() {
    use tokio::time::{sleep, Duration};

    // 创建异步任务
    let task1 = async {
        sleep(Duration::from_secs(1)).await;
        println!("任务1完成");
        1
    };

    let task2 = async {
        sleep(Duration::from_secs(2)).await;
        println!("任务2完成");
        2
    };

    // 并发执行
    let (result1, result2) = tokio::join!(task1, task2);
    println!("结果: {}, {}", result1, result2);
}

2. 异步通道

#[tokio::main]
async fn async_channels() {
    use tokio::sync::mpsc;

    let (tx, mut rx) = mpsc::channel(100);

    // 生产者任务
    let producer = tokio::spawn(async move {
        for i in 0..10 {
            tx.send(i).await.unwrap();
            tokio::time::sleep(Duration::from_millis(100)).await;
        }
    });

    // 消费者任务
    let consumer = tokio::spawn(async move {
        while let Some(value) = rx.recv().await {
            println!("收到: {}", value);
        }
    });

    let _ = tokio::join!(producer, consumer);
}

📊 八、线程局部存储

1. Thread Local Storage

use std::cell::RefCell;
use std::thread;

thread_local! {
    static COUNTER: RefCell<u32> = RefCell::new(0);
}

fn thread_local_example() {
    // 在主线程中设置
    COUNTER.with(|counter| {
        *counter.borrow_mut() = 42;
    });

    // 创建新线程,每个线程有自己的副本
    let handles: Vec<_> = (0..3)
        .map(|i| {
            thread::spawn(move || {
                COUNTER.with(|counter| {
                    *counter.borrow_mut() = i as u32;
                    println!("线程 {}: 计数器 = {}", i, *counter.borrow());
                });
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }

    // 主线程的计数器不变
    COUNTER.with(|counter| {
        println!("主线程计数器: {}", *counter.borrow());
    });
}

⚠️ 九、常见问题和解决方案

1. 死锁预防

use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;

fn deadlock_prevention() {
    let resource1 = Arc::new(Mutex::new(0));
    let resource2 = Arc::new(Mutex::new(0));

    // 预防死锁:总是以相同顺序获取锁
    let r1 = Arc::clone(&resource1);
    let r2 = Arc::clone(&resource2);

    let thread1 = thread::spawn(move || {
        // 总是先获取 resource1,再获取 resource2
        let _lock1 = r1.lock().unwrap();
        thread::sleep(Duration::from_millis(10));
        let _lock2 = r2.lock().unwrap();
        println!("线程1获取了两个锁");
    });

    let thread2 = thread::spawn(move || {
        // 也以相同顺序获取
        let _lock1 = resource1.lock().unwrap();
        thread::sleep(Duration::from_millis(10));
        let _lock2 = resource2.lock().unwrap();
        println!("线程2获取了两个锁");
    });

    thread1.join().unwrap();
    thread2.join().unwrap();
}

2. 避免锁中毒

use std::sync::{Arc, Mutex, PoisonError};
use std::thread;

fn poison_prevention() {
    let data = Arc::new(Mutex::new(0));

    let data_clone = Arc::clone(&data);
    let thread = thread::spawn(move || {
        let mut guard = data_clone.lock().unwrap();
        *guard += 1;
        panic!("线程崩溃!");  // 这会使锁中毒
    });

    // 等待线程完成
    let _ = thread.join();

    // 处理可能中毒的锁
    match data.lock() {
        Ok(mut guard) => {
            *guard += 1;
            println!("成功获取锁: {}", *guard);
        }
        Err(PoisonError { guard, .. }) => {
            println!("锁已中毒,但我们可以继续使用");
            let mut guard = guard.into_inner();
            *guard += 1;
            println!("恢复后的值: {}", *guard);
        }
    }
}

🎯 十、性能优化技巧

1. 减少锁竞争

use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Instant;

fn reduce_lock_contention() {
    let data = Arc::new(Mutex::new(vec![]));
    let mut handles = vec![];

    let start = Instant::now();

    for i in 0..10 {
        let data = Arc::clone(&data);
        let handle = thread::spawn(move || {
            // ❌ 不好的做法:频繁获取释放锁
            for j in 0..1000 {
                let mut vec = data.lock().unwrap();
                vec.push((i, j));
            }
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("频繁锁耗时: {:?}", start.elapsed());

    // ✅ 好的做法:批量处理
    let data2 = Arc::new(Mutex::new(vec![]));
    let mut handles2 = vec![];

    let start2 = Instant::now();

    for i in 0..10 {
        let data = Arc::clone(&data2);
        let handle = thread::spawn(move || {
            let mut local_vec = vec![];
            for j in 0..1000 {
                local_vec.push((i, j));
            }
            // 一次性添加到共享数据
            let mut vec = data.lock().unwrap();
            vec.extend(local_vec);
        });
        handles2.push(handle);
    }

    for handle in handles2 {
        handle.join().unwrap();
    }

    println!("批量处理耗时: {:?}", start2.elapsed());
}

2. 无锁数据结构

use crossbeam::queue::ArrayQueue;
use std::thread;

fn lock_free_example() {
    let queue = Arc::new(ArrayQueue::new(100));

    // 生产者
    let producer = thread::spawn({
        let queue = Arc::clone(&queue);
        move || {
            for i in 0..50 {
                while queue.push(i).is_err() {
                    thread::yield_now();
                }
            }
        }
    });

    // 消费者
    let consumer = thread::spawn({
        let queue = Arc::clone(&queue);
        move || {
            let mut count = 0;
            while count < 50 {
                if let Some(value) = queue.pop() {
                    println!("消费: {}", value);
                    count += 1;
                } else {
                    thread::yield_now();
                }
            }
        }
    });

    producer.join().unwrap();
    consumer.join().unwrap();
}

📋 十一、完整示例:Web 服务器线程池

use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::thread;
use std::time::Duration;

struct ThreadPool {
    workers: Vec<Worker>,
    sender: Option<std::sync::mpsc::Sender<Job>>,
}

type Job = Box<dyn FnOnce() + Send + 'static>;

impl ThreadPool {
    fn new(size: usize) -> ThreadPool {
        let (sender, receiver) = std::sync::mpsc::channel();
        let receiver = std::sync::Arc::new(std::sync::Mutex::new(receiver));

        let mut workers = Vec::with_capacity(size);

        for id in 0..size {
            workers.push(Worker::new(id, std::sync::Arc::clone(&receiver)));
        }

        ThreadPool {
            workers,
            sender: Some(sender),
        }
    }

    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender.as_ref().unwrap().send(job).unwrap();
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        drop(self.sender.take());

        for worker in &mut self.workers {
            if let Some(thread) = worker.thread.take() {
                thread.join().unwrap();
            }
        }
    }
}

struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    fn new(id: usize, receiver: std::sync::Arc<std::sync::Mutex<std::sync::mpsc::Receiver<Job>>>) -> Worker {
        let thread = thread::spawn(move || loop {
            let message = receiver.lock().unwrap().recv();

            match message {
                Ok(job) => {
                    println!("Worker {} 开始处理请求", id);
                    job();
                }
                Err(_) => {
                    println!("Worker {} 关闭", id);
                    break;
                }
            }
        });

        Worker {
            id,
            thread: Some(thread),
        }
    }
}

fn handle_connection(mut stream: TcpStream) {
    let mut buffer = [0; 1024];
    stream.read(&mut buffer).unwrap();

    let get = b"GET / HTTP/1.1\r\n";
    let sleep = b"GET /sleep HTTP/1.1\r\n";

    let (status_line, contents) = if buffer.starts_with(get) {
        ("HTTP/1.1 200 OK", "<h1>你好!</h1>")
    } else if buffer.starts_with(sleep) {
        thread::sleep(Duration::from_secs(5));
        ("HTTP/1.1 200 OK", "<h1>睡醒了!</h1>")
    } else {
        ("HTTP/1.1 404 NOT FOUND", "<h1>页面未找到</h1>")
    };

    let response = format!(
        "{}\r\nContent-Length: {}\r\n\r\n{}",
        status_line,
        contents.len(),
        contents
    );

    stream.write_all(response.as_bytes()).unwrap();
    stream.flush().unwrap();
}

fn main() {
    let listener = TcpListener::bind("127.0.0.1:7878").unwrap();
    let pool = ThreadPool::new(4);

    for stream in listener.incoming() {
        let stream = stream.unwrap();

        pool.execute(|| {
            handle_connection(stream);
        });
    }

    println!("服务器关闭");
}

🎯 十二、最佳实践总结

1. 选择正确的工具

场景 推荐工具 原因
任务并行 rayon 简单、自动并行化
I/O 密集型 tokio 异步、高并发
CPU 密集型 原生线程 控制更精细
简单通信 std::sync::mpsc 标准库、简单
复杂通信 crossbeam 功能更丰富
共享状态 Arc<Mutex<T>> 标准、安全
无锁操作 原子类型 性能更高

2. 避免常见陷阱

fn avoid_pitfalls() {
    // ❌ 陷阱1:过度使用线程
    // 不要为每个小任务创建线程

    // ❌ 陷阱2:忘记 Join
    // thread::spawn(|| { ... });  // 忘记保存句柄

    // ✅ 正确做法
    let handle = thread::spawn(|| { /* ... */ });
    handle.join().unwrap();

    // ❌ 陷阱3:共享可变状态无保护
    // let data = vec![1, 2, 3];
    // thread::spawn(|| data.push(4));  // 编译错误,Rust 会阻止

    // ✅ 正确做法
    use std::sync::{Arc, Mutex};
    let data = Arc::new(Mutex::new(vec![1, 2, 3]));
    let data_clone = Arc::clone(&data);
    thread::spawn(move || {
        data_clone.lock().unwrap().push(4);
    });
}

3. 性能模式

fn performance_patterns() {
    // 1. 工作窃取模式
    use rayon;
    let result: Vec<_> = (0..1000)
        .into_par_iter()  // 使用 Rayon 并行迭代
        .map(|x| x * 2)
        .collect();

    // 2. 分治模式
    fn parallel_sum(data: &[i32]) -> i32 {
        if data.len() <= 1000 {
            data.iter().sum()
        } else {
            let mid = data.len() / 2;
            let (left, right) = data.split_at(mid);

            let (left_sum, right_sum) = rayon::join(
                || parallel_sum(left),
                || parallel_sum(right),
            );

            left_sum + right_sum
        }
    }

    // 3. 流水线模式
    use crossbeam::channel;

    let (stage1_tx, stage1_rx) = channel::unbounded();
    let (stage2_tx, stage2_rx) = channel::unbounded();

    // 第一阶段
    thread::spawn(move || {
        for i in 0..100 {
            stage1_tx.send(i).unwrap();
        }
    });

    // 第二阶段
    thread::spawn(move || {
        for i in stage1_rx {
            stage2_tx.send(i * 2).unwrap();
        }
    });

    // 第三阶段
    thread::spawn(move || {
        for i in stage2_rx {
            println!("结果: {}", i);
        }
    });
}

📦 十三、Cargo.toml 依赖

[package]
name = "rust_concurrency"
version = "1.0.0"
edition = "2021"

[dependencies]
# 异步运行时
tokio = { version = "1.0", features = ["full"], optional = true }

# 并行数据处理
rayon = { version = "1.5", optional = true }

# 无锁数据结构
crossbeam = { version = "0.8", optional = true }

# 性能测量
criterion = { version = "0.3", optional = true }

[features]
default = []
full = ["tokio", "rayon", "crossbeam"]
async = ["tokio"]
parallel = ["rayon"]
lock-free = ["crossbeam"]
bench = ["criterion"]

[dev-dependencies]
criterion = "0.3"

💡 总结要点

  1. Rust 保证线程安全:编译时防止数据竞争
  2. 多种并发模型:线程、异步、无锁、数据并行
  3. 选择合适的工具:根据任务类型选择最佳方案
  4. 避免共享可变状态:尽量使用消息传递
  5. 性能优化:减少锁竞争,使用适当的数据结构
  6. 错误处理:正确处理锁中毒和线程崩溃

Rust 的多线程编程虽然学习曲线较陡,但一旦掌握,就能编写出既安全又高效的并发程序。编译器会帮你避免大部分常见的并发错误,让你可以专注于业务逻辑。