Rust 多线程编程完全指南
📚 一、Rust 多线程核心概念
1. Rust 的线程安全保证
Rust 通过所有权系统和类型系统在编译时防止数据竞争,这是 Rust 最强大的特性之一。
// Rust 保证:如果代码编译通过,就没有数据竞争
fn main() {
// 这是线程安全的基石
}
🎯 二、创建和管理线程
1. 基本线程创建
use std::thread;
use std::time::Duration;
fn basic_threads() {
println!("主线程开始");
// 创建新线程
let handle = thread::spawn(|| {
for i in 1..=5 {
println!("子线程: 计数 {}", i);
thread::sleep(Duration::from_millis(500));
}
});
// 主线程继续执行
for i in 1..=3 {
println!("主线程: 工作 {}", i);
thread::sleep(Duration::from_millis(300));
}
// 等待子线程完成
handle.join().unwrap();
println!("所有线程完成");
}
2. 使用 move 闭包获取所有权
fn move_closure_threads() {
let data = vec![1, 2, 3, 4, 5];
// 使用 move 将 data 的所有权转移到线程中
let handle = thread::spawn(move || {
println!("线程中的数据: {:?}", data);
// 这里可以安全地使用 data
});
// 这里不能再使用 data,因为所有权已转移
// println!("{:?}", data); // ❌ 编译错误
handle.join().unwrap();
}
3. 线程句柄和 Join
fn thread_handles() {
let mut handles = vec![];
// 创建多个线程
for i in 0..5 {
let handle = thread::spawn(move || {
println!("线程 {} 开始", i);
thread::sleep(Duration::from_millis(i * 100));
println!("线程 {} 结束", i);
i * 2 // 线程返回值
});
handles.push(handle);
}
// 收集所有线程的结果
let results: Vec<i32> = handles
.into_iter()
.filter_map(|h| h.join().ok())
.collect();
println!("所有线程结果: {:?}", results);
}
🔄 三、线程间通信
1. 通道(Channel)基础
use std::sync::mpsc; // 多生产者,单消费者
fn basic_channels() {
// 创建通道
let (tx, rx) = mpsc::channel();
// 创建发送线程
let sender = thread::spawn(move || {
for i in 0..5 {
tx.send(i).unwrap();
println!("发送: {}", i);
thread::sleep(Duration::from_millis(100));
}
});
// 在主线程接收
for received in rx {
println!("接收: {}", received);
}
sender.join().unwrap();
}
2. 多生产者
fn multiple_producers() {
let (tx, rx) = mpsc::channel();
// 创建多个发送者
let tx1 = tx.clone();
let tx2 = tx.clone();
// 线程1
let producer1 = thread::spawn(move || {
for i in 0..3 {
tx1.send(format!("线程1-消息{}", i)).unwrap();
thread::sleep(Duration::from_millis(50));
}
});
// 线程2
let producer2 = thread::spawn(move || {
for i in 0..3 {
tx2.send(format!("线程2-消息{}", i)).unwrap();
thread::sleep(Duration::from_millis(30));
}
});
// 在主线程接收所有消息
drop(tx); // 关闭原始发送者
for msg in rx {
println!("收到: {}", msg);
}
producer1.join().unwrap();
producer2.join().unwrap();
}
3. 同步通道和异步通道
fn sync_vs_async_channels() {
use std::sync::mpsc;
// 异步通道(无界缓冲区)
let (async_tx, async_rx) = mpsc::channel();
// 同步通道(有界缓冲区)
let (sync_tx, sync_rx) = mpsc::sync_channel(3); // 缓冲区大小为3
// 异步发送不会阻塞
let async_sender = thread::spawn(move || {
for i in 0..10 {
async_tx.send(i).unwrap();
println!("异步发送: {}", i);
}
});
// 同步发送在缓冲区满时会阻塞
let sync_sender = thread::spawn(move || {
for i in 0..10 {
sync_tx.send(i).unwrap();
println!("同步发送: {}", i);
thread::sleep(Duration::from_millis(50));
}
});
// 接收
let receiver = thread::spawn(move || {
thread::sleep(Duration::from_millis(100));
for msg in async_rx {
println!("异步接收: {}", msg);
}
for msg in sync_rx {
println!("同步接收: {}", msg);
}
});
async_sender.join().unwrap();
sync_sender.join().unwrap();
receiver.join().unwrap();
}
🔐 四、共享状态并发
1. 互斥锁(Mutex)
use std::sync::{Arc, Mutex};
fn mutex_example() {
// 使用 Arc 实现线程间共享所有权
let counter = Arc::new(Mutex::new(0));
let mut handles = vec![];
for _ in 0..10 {
let counter = Arc::clone(&counter);
let handle = thread::spawn(move || {
// 获取锁
let mut num = counter.lock().unwrap();
*num += 1;
// 锁在离开作用域时自动释放
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("最终计数: {}", *counter.lock().unwrap());
}
2. 读写锁(RwLock)
use std::sync::{Arc, RwLock};
use std::time::Duration;
fn rwlock_example() {
let data = Arc::new(RwLock::new(vec![1, 2, 3]));
let mut handles = vec![];
// 多个读线程
for i in 0..5 {
let data = Arc::clone(&data);
let handle = thread::spawn(move || {
// 获取读锁
let reader = data.read().unwrap();
println!("读线程{}: {:?}", i, *reader);
thread::sleep(Duration::from_millis(10));
});
handles.push(handle);
}
// 写线程
let data_write = Arc::clone(&data);
let write_handle = thread::spawn(move || {
// 获取写锁
let mut writer = data_write.write().unwrap();
writer.push(4);
writer.push(5);
println!("写线程: 添加了新元素");
});
handles.push(write_handle);
for handle in handles {
handle.join().unwrap();
}
println!("最终数据: {:?}", *data.read().unwrap());
}
3. 条件变量(Condvar)
use std::sync::{Arc, Mutex, Condvar};
fn condvar_example() {
let pair = Arc::new((Mutex::new(false), Condvar::new()));
let pair2 = Arc::clone(&pair);
// 等待线程
let waiter = thread::spawn(move || {
let (lock, cvar) = &*pair2;
let mut started = lock.lock().unwrap();
println!("等待线程: 等待条件...");
while !*started {
started = cvar.wait(started).unwrap();
}
println!("等待线程: 条件满足!");
});
// 通知线程
thread::sleep(Duration::from_secs(1));
let (lock, cvar) = &*pair;
let mut started = lock.lock().unwrap();
*started = true;
println!("通知线程: 通知所有等待者");
cvar.notify_all();
waiter.join().unwrap();
}
⚡ 五、原子操作
1. 原子类型
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
fn atomic_example() {
let counter = Arc::new(AtomicUsize::new(0));
let mut handles = vec![];
for _ in 0..10 {
let counter = Arc::clone(&counter);
let handle = thread::spawn(move || {
for _ in 0..1000 {
// 原子增加,不需要锁
counter.fetch_add(1, Ordering::SeqCst);
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("最终计数: {}", counter.load(Ordering::SeqCst));
}
2. 内存排序
fn memory_ordering() {
use std::sync::atomic::{AtomicBool, AtomicUsize};
use std::sync::Arc;
let flag = Arc::new(AtomicBool::new(false));
let data = Arc::new(AtomicUsize::new(0));
let flag_clone = Arc::clone(&flag);
let data_clone = Arc::clone(&data);
// 生产者线程
let producer = thread::spawn(move || {
data_clone.store(42, Ordering::Relaxed);
flag_clone.store(true, Ordering::Release); // 释放语义
});
// 消费者线程
let consumer = thread::spawn(move || {
while !flag.load(Ordering::Acquire) { // 获取语义
// 忙等待
}
let value = data.load(Ordering::Relaxed);
println!("读取的值: {}", value);
});
producer.join().unwrap();
consumer.join().unwrap();
}
🧵 六、线程池
1. 简单线程池实现
use std::sync::{Arc, Mutex, mpsc};
use std::thread;
type Job = Box<dyn FnOnce() + Send + 'static>;
struct Worker {
id: usize,
thread: Option<thread::JoinHandle<()>>,
}
impl Worker {
fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
let thread = thread::spawn(move || loop {
let message = receiver.lock().unwrap().recv();
match message {
Ok(job) => {
println!("Worker {} 开始执行任务", id);
job();
}
Err(_) => {
println!("Worker {} 关闭", id);
break;
}
}
});
Worker {
id,
thread: Some(thread),
}
}
}
pub struct ThreadPool {
workers: Vec<Worker>,
sender: Option<mpsc::Sender<Job>>,
}
impl ThreadPool {
pub fn new(size: usize) -> ThreadPool {
assert!(size > 0);
let (sender, receiver) = mpsc::channel();
let receiver = Arc::new(Mutex::new(receiver));
let mut workers = Vec::with_capacity(size);
for id in 0..size {
workers.push(Worker::new(id, Arc::clone(&receiver)));
}
ThreadPool {
workers,
sender: Some(sender),
}
}
pub fn execute<F>(&self, f: F)
where
F: FnOnce() + Send + 'static,
{
let job = Box::new(f);
self.sender.as_ref().unwrap().send(job).unwrap();
}
}
impl Drop for ThreadPool {
fn drop(&mut self) {
drop(self.sender.take());
for worker in &mut self.workers {
println!("关闭 worker {}", worker.id);
if let Some(thread) = worker.thread.take() {
thread.join().unwrap();
}
}
}
}
fn thread_pool_example() {
let pool = ThreadPool::new(4);
for i in 0..8 {
pool.execute(move || {
println!("执行任务 {}", i);
thread::sleep(Duration::from_millis(500));
});
}
thread::sleep(Duration::from_secs(3));
}
🔄 七、异步编程基础
1. 使用 Tokio
// Cargo.toml: tokio = { version = "1.0", features = ["full"] }
#[tokio::main]
async fn tokio_example() {
use tokio::time::{sleep, Duration};
// 创建异步任务
let task1 = async {
sleep(Duration::from_secs(1)).await;
println!("任务1完成");
1
};
let task2 = async {
sleep(Duration::from_secs(2)).await;
println!("任务2完成");
2
};
// 并发执行
let (result1, result2) = tokio::join!(task1, task2);
println!("结果: {}, {}", result1, result2);
}
2. 异步通道
#[tokio::main]
async fn async_channels() {
use tokio::sync::mpsc;
let (tx, mut rx) = mpsc::channel(100);
// 生产者任务
let producer = tokio::spawn(async move {
for i in 0..10 {
tx.send(i).await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
}
});
// 消费者任务
let consumer = tokio::spawn(async move {
while let Some(value) = rx.recv().await {
println!("收到: {}", value);
}
});
let _ = tokio::join!(producer, consumer);
}
📊 八、线程局部存储
1. Thread Local Storage
use std::cell::RefCell;
use std::thread;
thread_local! {
static COUNTER: RefCell<u32> = RefCell::new(0);
}
fn thread_local_example() {
// 在主线程中设置
COUNTER.with(|counter| {
*counter.borrow_mut() = 42;
});
// 创建新线程,每个线程有自己的副本
let handles: Vec<_> = (0..3)
.map(|i| {
thread::spawn(move || {
COUNTER.with(|counter| {
*counter.borrow_mut() = i as u32;
println!("线程 {}: 计数器 = {}", i, *counter.borrow());
});
})
})
.collect();
for handle in handles {
handle.join().unwrap();
}
// 主线程的计数器不变
COUNTER.with(|counter| {
println!("主线程计数器: {}", *counter.borrow());
});
}
⚠️ 九、常见问题和解决方案
1. 死锁预防
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
fn deadlock_prevention() {
let resource1 = Arc::new(Mutex::new(0));
let resource2 = Arc::new(Mutex::new(0));
// 预防死锁:总是以相同顺序获取锁
let r1 = Arc::clone(&resource1);
let r2 = Arc::clone(&resource2);
let thread1 = thread::spawn(move || {
// 总是先获取 resource1,再获取 resource2
let _lock1 = r1.lock().unwrap();
thread::sleep(Duration::from_millis(10));
let _lock2 = r2.lock().unwrap();
println!("线程1获取了两个锁");
});
let thread2 = thread::spawn(move || {
// 也以相同顺序获取
let _lock1 = resource1.lock().unwrap();
thread::sleep(Duration::from_millis(10));
let _lock2 = resource2.lock().unwrap();
println!("线程2获取了两个锁");
});
thread1.join().unwrap();
thread2.join().unwrap();
}
2. 避免锁中毒
use std::sync::{Arc, Mutex, PoisonError};
use std::thread;
fn poison_prevention() {
let data = Arc::new(Mutex::new(0));
let data_clone = Arc::clone(&data);
let thread = thread::spawn(move || {
let mut guard = data_clone.lock().unwrap();
*guard += 1;
panic!("线程崩溃!"); // 这会使锁中毒
});
// 等待线程完成
let _ = thread.join();
// 处理可能中毒的锁
match data.lock() {
Ok(mut guard) => {
*guard += 1;
println!("成功获取锁: {}", *guard);
}
Err(PoisonError { guard, .. }) => {
println!("锁已中毒,但我们可以继续使用");
let mut guard = guard.into_inner();
*guard += 1;
println!("恢复后的值: {}", *guard);
}
}
}
🎯 十、性能优化技巧
1. 减少锁竞争
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Instant;
fn reduce_lock_contention() {
let data = Arc::new(Mutex::new(vec![]));
let mut handles = vec![];
let start = Instant::now();
for i in 0..10 {
let data = Arc::clone(&data);
let handle = thread::spawn(move || {
// ❌ 不好的做法:频繁获取释放锁
for j in 0..1000 {
let mut vec = data.lock().unwrap();
vec.push((i, j));
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("频繁锁耗时: {:?}", start.elapsed());
// ✅ 好的做法:批量处理
let data2 = Arc::new(Mutex::new(vec![]));
let mut handles2 = vec![];
let start2 = Instant::now();
for i in 0..10 {
let data = Arc::clone(&data2);
let handle = thread::spawn(move || {
let mut local_vec = vec![];
for j in 0..1000 {
local_vec.push((i, j));
}
// 一次性添加到共享数据
let mut vec = data.lock().unwrap();
vec.extend(local_vec);
});
handles2.push(handle);
}
for handle in handles2 {
handle.join().unwrap();
}
println!("批量处理耗时: {:?}", start2.elapsed());
}
2. 无锁数据结构
use crossbeam::queue::ArrayQueue;
use std::thread;
fn lock_free_example() {
let queue = Arc::new(ArrayQueue::new(100));
// 生产者
let producer = thread::spawn({
let queue = Arc::clone(&queue);
move || {
for i in 0..50 {
while queue.push(i).is_err() {
thread::yield_now();
}
}
}
});
// 消费者
let consumer = thread::spawn({
let queue = Arc::clone(&queue);
move || {
let mut count = 0;
while count < 50 {
if let Some(value) = queue.pop() {
println!("消费: {}", value);
count += 1;
} else {
thread::yield_now();
}
}
}
});
producer.join().unwrap();
consumer.join().unwrap();
}
📋 十一、完整示例:Web 服务器线程池
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::thread;
use std::time::Duration;
struct ThreadPool {
workers: Vec<Worker>,
sender: Option<std::sync::mpsc::Sender<Job>>,
}
type Job = Box<dyn FnOnce() + Send + 'static>;
impl ThreadPool {
fn new(size: usize) -> ThreadPool {
let (sender, receiver) = std::sync::mpsc::channel();
let receiver = std::sync::Arc::new(std::sync::Mutex::new(receiver));
let mut workers = Vec::with_capacity(size);
for id in 0..size {
workers.push(Worker::new(id, std::sync::Arc::clone(&receiver)));
}
ThreadPool {
workers,
sender: Some(sender),
}
}
fn execute<F>(&self, f: F)
where
F: FnOnce() + Send + 'static,
{
let job = Box::new(f);
self.sender.as_ref().unwrap().send(job).unwrap();
}
}
impl Drop for ThreadPool {
fn drop(&mut self) {
drop(self.sender.take());
for worker in &mut self.workers {
if let Some(thread) = worker.thread.take() {
thread.join().unwrap();
}
}
}
}
struct Worker {
id: usize,
thread: Option<thread::JoinHandle<()>>,
}
impl Worker {
fn new(id: usize, receiver: std::sync::Arc<std::sync::Mutex<std::sync::mpsc::Receiver<Job>>>) -> Worker {
let thread = thread::spawn(move || loop {
let message = receiver.lock().unwrap().recv();
match message {
Ok(job) => {
println!("Worker {} 开始处理请求", id);
job();
}
Err(_) => {
println!("Worker {} 关闭", id);
break;
}
}
});
Worker {
id,
thread: Some(thread),
}
}
}
fn handle_connection(mut stream: TcpStream) {
let mut buffer = [0; 1024];
stream.read(&mut buffer).unwrap();
let get = b"GET / HTTP/1.1\r\n";
let sleep = b"GET /sleep HTTP/1.1\r\n";
let (status_line, contents) = if buffer.starts_with(get) {
("HTTP/1.1 200 OK", "<h1>你好!</h1>")
} else if buffer.starts_with(sleep) {
thread::sleep(Duration::from_secs(5));
("HTTP/1.1 200 OK", "<h1>睡醒了!</h1>")
} else {
("HTTP/1.1 404 NOT FOUND", "<h1>页面未找到</h1>")
};
let response = format!(
"{}\r\nContent-Length: {}\r\n\r\n{}",
status_line,
contents.len(),
contents
);
stream.write_all(response.as_bytes()).unwrap();
stream.flush().unwrap();
}
fn main() {
let listener = TcpListener::bind("127.0.0.1:7878").unwrap();
let pool = ThreadPool::new(4);
for stream in listener.incoming() {
let stream = stream.unwrap();
pool.execute(|| {
handle_connection(stream);
});
}
println!("服务器关闭");
}
🎯 十二、最佳实践总结
1. 选择正确的工具
| 场景 | 推荐工具 | 原因 |
|---|---|---|
| 任务并行 | rayon |
简单、自动并行化 |
| I/O 密集型 | tokio |
异步、高并发 |
| CPU 密集型 | 原生线程 | 控制更精细 |
| 简单通信 | std::sync::mpsc |
标准库、简单 |
| 复杂通信 | crossbeam |
功能更丰富 |
| 共享状态 | Arc<Mutex<T>> |
标准、安全 |
| 无锁操作 | 原子类型 | 性能更高 |
2. 避免常见陷阱
fn avoid_pitfalls() {
// ❌ 陷阱1:过度使用线程
// 不要为每个小任务创建线程
// ❌ 陷阱2:忘记 Join
// thread::spawn(|| { ... }); // 忘记保存句柄
// ✅ 正确做法
let handle = thread::spawn(|| { /* ... */ });
handle.join().unwrap();
// ❌ 陷阱3:共享可变状态无保护
// let data = vec![1, 2, 3];
// thread::spawn(|| data.push(4)); // 编译错误,Rust 会阻止
// ✅ 正确做法
use std::sync::{Arc, Mutex};
let data = Arc::new(Mutex::new(vec![1, 2, 3]));
let data_clone = Arc::clone(&data);
thread::spawn(move || {
data_clone.lock().unwrap().push(4);
});
}
3. 性能模式
fn performance_patterns() {
// 1. 工作窃取模式
use rayon;
let result: Vec<_> = (0..1000)
.into_par_iter() // 使用 Rayon 并行迭代
.map(|x| x * 2)
.collect();
// 2. 分治模式
fn parallel_sum(data: &[i32]) -> i32 {
if data.len() <= 1000 {
data.iter().sum()
} else {
let mid = data.len() / 2;
let (left, right) = data.split_at(mid);
let (left_sum, right_sum) = rayon::join(
|| parallel_sum(left),
|| parallel_sum(right),
);
left_sum + right_sum
}
}
// 3. 流水线模式
use crossbeam::channel;
let (stage1_tx, stage1_rx) = channel::unbounded();
let (stage2_tx, stage2_rx) = channel::unbounded();
// 第一阶段
thread::spawn(move || {
for i in 0..100 {
stage1_tx.send(i).unwrap();
}
});
// 第二阶段
thread::spawn(move || {
for i in stage1_rx {
stage2_tx.send(i * 2).unwrap();
}
});
// 第三阶段
thread::spawn(move || {
for i in stage2_rx {
println!("结果: {}", i);
}
});
}
📦 十三、Cargo.toml 依赖
[package]
name = "rust_concurrency"
version = "1.0.0"
edition = "2021"
[dependencies]
# 异步运行时
tokio = { version = "1.0", features = ["full"], optional = true }
# 并行数据处理
rayon = { version = "1.5", optional = true }
# 无锁数据结构
crossbeam = { version = "0.8", optional = true }
# 性能测量
criterion = { version = "0.3", optional = true }
[features]
default = []
full = ["tokio", "rayon", "crossbeam"]
async = ["tokio"]
parallel = ["rayon"]
lock-free = ["crossbeam"]
bench = ["criterion"]
[dev-dependencies]
criterion = "0.3"
💡 总结要点
- Rust 保证线程安全:编译时防止数据竞争
- 多种并发模型:线程、异步、无锁、数据并行
- 选择合适的工具:根据任务类型选择最佳方案
- 避免共享可变状态:尽量使用消息传递
- 性能优化:减少锁竞争,使用适当的数据结构
- 错误处理:正确处理锁中毒和线程崩溃
Rust 的多线程编程虽然学习曲线较陡,但一旦掌握,就能编写出既安全又高效的并发程序。编译器会帮你避免大部分常见的并发错误,让你可以专注于业务逻辑。