A simple ATM implementation based on message queue

对《C++ Concurrency In Action》第4.4.2部分的一个记录,主要介绍了如何使用消息队列来实现线程间的交互。这种做法叫做Communicating Sequential Processes,简称CSP,其思路就是如果线程间没有共享的数据那么分析起来就会简单很多,我们只需要考虑每个线程在收到特定的消息时的行为即可,每个线程可以视作一个有限状态机

ATM机的实现思路

现在思考一下ATM机的工作流程,初始状态什么都没有,当插入银行卡后会进入输入密码的环节,此时可以输入密码、回退或是取消。当输入6个数字后进行密码校验,如果成功则进入账户页面。账户页面可以取款、查询余额,取款时需要输入金额并向银行确认交易结果,成功则机器取出对应的钞票,否则提示失败。总体上来看我们可以分成三个独立的线程:

  1. 逻辑线程:处理用户输入、状态流转
  2. 硬件线程:展示对应信息、提取现金
  3. 银行线程:校验密码和余额情况

流程图

Image

具体实现

首先实现一个消息队列的模型,支持任意类型的消息

Message类

// 基类
struct MessageBase {
public:
  virtual ~MessageBase() {}
};

// 包装类,模板参数指定消息类型
template<typename Message>
struct WrapperMessage : public MessageBase {
public:
  explicit WrapperMessage(const Message& content) : content_(content) {}
  Message content_;
};

MessageQueue类

// 使用互斥锁和条件变量管理的消息队列
class MessageQueue {
public:
  template<typename Message>
  void Push(const Message& message) {
    std::lock_guard<std::mutex> lg(m_);
    queue_.push(std::make_shared<WrapperMessage<Message>>(message));
    cv_.notify_all();
  }

  std::shared_ptr<MessageBase> WaitAndPop() {
    std::unique_lock<std::mutex> l(m_);
    cv_.wait([&]{
        return !queue_.empty();
    });
    auto f = queue_.front();
    queue_.pop();
    return f;
  }

private:
  std::queue<std::shared_ptr<MessageBase>> queue_;
  std::condition_variable cv_;
  std::mutex m_;
};

Sender和Receiver类

// Sender类引用一个消息队列,并向其推送自定义消息
class Sender {
public:
  explicit Sender(MessageQueue* queue) : queue_(queue) {}
  template<typename Message>
  void Send(const Message& message) {
    if (!queue_) {
        return;
    }
    queue_->Push(message);
  }

private:
  MessageQueue* queue_ = nullptr;
};

// Receiver类持有一个消息队列,创建Dispatcher类来处理其中的消息。支持转化为Sender来利用其中的队列发送消息
class Receiver {
public:
  Dispatcher Wait() {
    return Dispatcher(&queue_);
  }
  operator Sender() {
    return Sender(&queue_);
  }

private:
  MessageQueue queue_;
};

Dispatcher类

Dispatcher类是一个专门负责分发消息的类,当它析构时它会尝试将对应队列中所有的消息分发出去。这其实只是一个兜底操作,大多数情况是通过调用Handle函数来处理特定的消息。注意这里的chained_成员变量用来标记这个Dispatcher是不是已经“链”进去了,主要是避免重复进行分发。在实现中我们将Dispatcher一个个链起来处理消息的时候会看的更清楚 TemplateDispatcherDispatcher类几乎一样,但是增加了处理的Message和Func作为模板参数

struct CloseQueue {};

class Dispatcher {
public:
  explicit Dispatcher(MessageQueue* queue) : queue_(queue) {}
  Dispatcher(Dispatcher&& other) : queue_(other.queue_), chained_(other.chained_) {
    other.chained_ = true;
  }

  ~Dispatcher() noexcept(false) {
    if (chained_) {
        return;
    }
    WaitAndDispatch();
  }

  Dispatcher(const Dispatcher& other) = delete;
  Dispatcher& operator=(const Dispatcher& other) = delete;

  template <typename Dispatcher, typename Msg, typename Func>
  friend class TemplateDispatcher;

  template<typename Message, typename Func>
  TemplateDispatcher<Dispatcher, Message, Func> Handle(Func&& f) {
    return TemplateDispatcher(queue_, this, std::forward<Func>(f));
  }

  bool Dispatch(const std::shared_ptr<MessageBase>& message) {
    if (dynamic_cast<WrapperMessage<CloseQueue>*>(message.get())) {
        throw CloseQueue();
    }
    return false;
  }

private:
  void WaitAndDispatch() {
    if (!queue_) {
        return;
    }
    for (;;) {
        auto message = queue_->WaitAndPop();
        Dispatch(message);
    }
  }

private:
  MessageQueue* queue_ = nullptr;
  bool chained_ = false;
};

template<typename PreviousDispatcher, typename Message, typename Func>
class TemplateDispatcher {
public:
  TemplateDispatcher(MessageQueue* queue, PreviousDispatcher* previous, Func&& f) :
  queue_(queue), chained_(previous->chained_), previous_(previous), f_(std::forward<Func>(f)) {
    previous->chained_ = true;
  }

  TemplateDispatcher(TemplateDispatcher&& other) :
  queue_(other.queue_), chained_(other.chained_), previous_(other.previous_), f_(std::move(other.f_)) {
    other.chained_ = true;
  }

  TemplateDispatcher(const TemplateDispatcher &other) = delete;
  TemplateDispatcher &operator=(const TemplateDispatcher &other) = delete;

  template <typename Dispatcher, typename OtherMsg, typename OtherFunc>
  friend class TemplateDispatcher;
  
  template<typename OtherMessage, typename OtherFunc>
  TemplateDispatcher<TemplateDispatcher, OtherMessage, OtherFunc> Handle(OtherFunc&& f) {
    return TemplateDispatcher(queue_, this, std::forward<OtherFunc>(f));
  }

  ~TemplateDispatcher() noexcept(false) {
    if (chained_) {
        return;
    }
    WaitAndDispatch();
  }

  bool Dispatch(const std::shared_ptr<MessageBase>& message) {
    if (auto* wrapper_message = dynamic_cast<WrapperMessage<Message>*>(message.get())) {
        f_(wrapper_message->content_);
        return true;
    } else {
        return previous_->Dispatch(message);
    }
    return false;
  }

private:
  void WaitAndDispatch() {
    if (!queue_) {
        return;
    }
    for (;;) {
        auto message = queue_->WaitAndPop();
        if (Dispatch(message)) {
            break;
        }
    }
  }

private:
  MessageQueue* queue_ = nullptr;
  bool chained_ = false;
  PreviousDispatcher* previous_ = nullptr;
  Func f_;
};

Dispatcher使用实例:

Receiver receiver;
receiver.Wait().Handle<Message1>([](const Message1&){
  // Do something...
}).Handle<Message2>([](const Message2&){
  // Do something...
});

过程分析:

  • 调用这行代码时会创建若干个DispatcherTemplateDispatcher对象,它们引用同一个MessageQueue
  • 执行完这条语句后,从后向前依次调用每个对象的析构函数,如果队列中没有消息则阻塞在WaitAndDispatch方法,否则取出消息后试图进行分发,如果当前的Dispatcher处理不了(消息类型不匹配)就继续传递给前一个进行处理
  • 如果某条消息经过了处理,就从析构函数的无限循环中退出。由于chained_变量的存在,这条消息最多只会被处理一次

入口代码

在具体实现状态机前,可以先大致写出入口代码

  1. 我们需要三个对象来分别处理主状态机、硬件以及银行的逻辑
  2. 这三个对象分别在三个线程上进行各自的工作
  3. 它们通过消息队列来互相发送消息。例如初始状态状态机线程等待插入银行卡,当插入银行卡后由硬件线程输出一行欢迎语句,接下来主线程接受用户输入,并发送给状态机线程处理相关逻辑,输入完毕后则发送给银行线程进行校验……
...
BankMachine bank_machine(name, pin, balance);
HardwareMachine hardware_machine;
ATMMachine atm_machine(bank_machine.GetSender(),
                        hardware_machine.GetSender());
messaging::Sender sender(atm_machine.GetSender());
std::thread atm_thread(&ATMMachine::Run, &atm_machine);
std::thread bank_thread(&BankMachine::Run, &bank_machine);
std::thread hardware_thread(&HardwareMachine::Run, &hardware_machine);
...

atm_machine.Done();
hardware_machine.Done();
bank_machine.Done();
atm_thread.join();
hardware_thread.join();
bank_thread.join();

剩下的过程其实就是慢慢根据状态流转图去完善这三个类了,具体代码这里就不详细列出了,可以直接在Github仓库查看