stateMachine.hpp
Go to the documentation of this file.
1 #ifndef STATE_MACHINE_HPP
2 #define STATE_MACHINE_HPP
3 
4 #include <cassert>
5 #include <chrono>
6 #include <cstdint>
7 #include <functional>
8 #include <iostream>
9 #include <limits>
10 #include <memory>
11 #include <mutex>
12 #include <queue>
13 #include <stdexcept>
14 #include <unordered_map>
15 #include <vector>
16 
17 #include "bufArray.hpp"
18 #include "common.hpp"
19 #include "exceptions.hpp"
20 #include "spinlock.hpp"
21 
66 template <class Identifier, class Packet> class StateMachine {
67 private:
// Types supplied by the Identifier template parameter: the connection key and
// the hash functor used for every connection-keyed map below.
68  using ConnectionID = typename Identifier::ConnectionID;
69  using Hasher = typename Identifier::Hasher;
// Sentinel stored in State::timeoutID when no timeout is pending.
70  static constexpr auto timeoutIDInvalid = std::numeric_limits<uint32_t>::max();
71 
72 public:
73  /*
74  * XXX -------------------------------------------- XXX
75  * Public declarations
76  * XXX -------------------------------------------- XXX
77  */
78 
79  struct State;
80  class FunIface;
81 
// Handler run for a packet of a connection in a given state: receives the
// connection's State, the packet, and the FunIface for this invocation.
83  using stateFun = std::function<void(State &, Packet *, FunIface &)>;
84 
// Handler run when a timeout armed via FunIface::setTimeout() expires.
86  using timeoutFun = std::function<void(State &, FunIface &)>;
87 
// Sentinel StateID: used by State() and as the default endStateID.
89  static constexpr auto StateIDInvalid = std::numeric_limits<StateID>::max();
90 
// Per-connection bookkeeping kept in stateTable and in the connection pool:
// opaque user data, the current state ID, and the ID of the pending timeout
// (timeoutIDInvalid when none is pending).
// NOTE(review): this listing lost source lines 99 (presumably the
// `StateID state;` member that the constructors initialize), 103 (the header
// `State(StateID state, void *stateData)` of the constructor whose initializer
// list follows) and 106 (the copy constructor's initializer list/body).
96  struct State {
97  public:
98  void *stateData;
100  uint32_t timeoutID;
101 
102  State() : stateData(nullptr), state(StateIDInvalid), timeoutID(timeoutIDInvalid){};
104  : stateData(stateData), state(state), timeoutID(timeoutIDInvalid){};
105  State(const State &s)
107 
// Overwrite all three fields with those of `s`.
108  void set(const State &s) {
109  stateData = s.stateData;
110  state = s.state;
111  timeoutID = s.timeoutID;
112  }
113  };
114 
115  /*
116  * XXX -------------------------------------------- XXX
117  * Interface exposed to state functions
118  * XXX -------------------------------------------- XXX
119  */
120 
123  class FunIface {
124  private:
// Only StateMachine can construct a FunIface (the constructor is private).
125  friend class StateMachine<Identifier, Packet>;
126 
// NOTE(review): source line 127 — the `sm` back-pointer member that the
// constructor initializes and the methods dereference — is missing here.
// Index of the current packet in pktsBA (UINT32_MAX when created for a timeout).
128  uint32_t pktIdx;
// The batch the current packet belongs to.
129  BufArray<Packet> &pktsBA;
// Connection the current invocation belongs to.
130  ConnectionID &cID;
// State of that connection; transition()/transitionNow() write state.state.
131  State &state;
// Cleared by freePkt(); the destructor then marks the packet for drop.
132  bool sendPkt;
// Set by transitionNow(); the destructor then runs the new state's function.
133  bool immediateTransition;
134 
135  // Private -> nobody can misuse any FunIface objects
136  FunIface(StateMachine<Identifier, Packet> *sm, uint32_t pktIdx,
137  BufArray<Packet> &pktsBA, ConnectionID &cID, State &state)
138  : sm(sm), pktIdx(pktIdx), pktsBA(pktsBA), cID(cID), state(state), sendPkt(true),
139  immediateTransition(false){};
140 
141  public:
// Body of ~FunIface(); its header (source line 142) is missing from this
// listing. On destruction it marks the packet for drop if freePkt() was
// called, and, if transitionNow() was called, runs the function registered
// for the new state once with the same packet, removing the connection if
// that function leaves it in endStateID.
// NOTE(review): the branch below reads `state`; if the State was erased
// (removeState) while this FunIface was alive, that reference dangles.
// NOTE(review): a transitionNow() made by the function invoked here only sets
// immediateTransition again; it is not acted upon (single level only).
143  if (!sendPkt) {
144  pktsBA.markDropPkt(pktIdx);
145  }
146 
147  // We are not in the endstate - checked by runPkt
148 
149  if (immediateTransition) {
150  /*
151  auto sfIt = sm->functions.find(state.state);
152  if (sfIt == sm->functions.end()) {
153  D(
154  std::cout
155  << "FunIface::~FunIface() Didn't find a function for this state"
156  << std::endl;)
157  // Don't throw, we are in a destructor
158  // throw std::runtime_error("FunIface::~FunIface() No such function
159  // found");
160  }
161 
162  D(std::cout << "Running Function" << std::endl;)
163  (sfIt->second)(state, pktsBA[pktIdx], *this);
164  */
165 
// NOTE(review): for an empty function table size() - 1 wraps to SIZE_MAX, so
// this assert never fires; `state.state < sm->functions.size()` is the
// intended bound check.
166  assert((sm->functions.size() - 1) >= state.state);
167  auto fun = sm->functions[state.state];
168  assert(fun != nullptr);
169 
170  D(std::cout << "Running Function" << std::endl;)
// NOTE(review): runPktBatch() creates FunIface with pktIdx = UINT32_MAX for
// timeout callbacks, so reaching this from a timeout indexes pktsBA out of
// range (as does markDropPkt above after freePkt()) — confirm BufArray bounds.
171  fun(state, pktsBA[pktIdx], *this);
172 
173  if (state.state == sm->endStateID) {
174  D(std::cout << "Reached endStateID - deleting connection" << std::endl;)
175  sm->removeState(cID);
176  }
177  }
178  }
179 
181  void freePkt() { sendPkt = false; }
182 
187  Packet *getPkt() {
188  throw new std::runtime_error("StateMachine::FunIface::getPkt() not implemented");
189  }
190 
198  void transition(StateID newState) { state.state = newState; }
199 
210  void transitionNow(StateID newState) {
211  state.state = newState;
212  this->immediateTransition = true;
213  }
214 
219  void setTimeout(std::chrono::milliseconds timeout, timeoutFun fun) {
220  std::chrono::time_point<std::chrono::steady_clock> now =
221  std::chrono::steady_clock::now();
222  std::chrono::time_point<std::chrono::steady_clock> then = now + timeout;
223 
224  struct Timeout t;
225  t.time = then;
226  t.timeoutID = sm->curTimeoutID++;
227 
228  state.timeoutID = t.timeoutID;
229 
230  sm->timeoutsQ.push(t);
231 
232  // This is an alternative to emplace
233  // emplace should be better
234  /*
235  auto td = std::make_unique<struct TimeoutData>(cID, fun);
236  sm->timeoutFunctions.insert(
237  std::pair<uint32_t, std::unique_ptr<struct TimeoutData>>(
238  t.timeoutID, std::move(td)));
239  */
240 
241  sm->timeoutFunctions.emplace(
242  t.timeoutID, std::make_unique<struct TimeoutData>(cID, fun));
243  }
244  };
245 
247  private:
248  const unsigned int numBuckets = 8;
249  const uint64_t bucketMask = 0b111;
250 
251  std::unordered_map<ConnectionID, State, Hasher> **newStates;
252  SpinLockCLSize **newStatesLock;
253 
254  public:
// Body of the ConnectionPool constructor (its header, source line 255, is
// missing from this listing): allocates one hash map and one spinlock per
// bucket.
// NOTE(review): the malloc() results are not checked for nullptr, and the
// class owns these raw arrays with no visible deleted copy constructor or
// copy assignment, so copying a ConnectionPool would double-free in the
// destructor.
256  newStates = reinterpret_cast<std::unordered_map<ConnectionID, State, Hasher> **>(
257  malloc(sizeof(void *) * numBuckets));
258  for (unsigned int i = 0; i < numBuckets; i++) {
259  newStates[i] = new std::unordered_map<ConnectionID, State, Hasher>();
260  }
261 
262  newStatesLock =
263  reinterpret_cast<SpinLockCLSize **>(malloc(sizeof(void *) * numBuckets));
264  for (unsigned int i = 0; i < numBuckets; i++) {
265  newStatesLock[i] = new SpinLockCLSize();
266  }
267  };
268 
// Body of the ConnectionPool destructor (its header, source line 269, is
// missing from this listing): frees the per-bucket maps, then the spinlocks.
// NOTE(review): the maps are deleted before any lock is taken, and each lock
// is acquired only right before it is deleted, so this does not protect the
// teardown against a concurrent add()/findAndErase().
270  for (unsigned int i = 0; i < numBuckets; i++) {
271  delete (newStates[i]);
272  }
273  free(newStates);
274 
275  for (unsigned int i = 0; i < numBuckets; i++) {
276  newStatesLock[i]->lock();
277  delete (newStatesLock[i]);
278  }
279  free(newStatesLock);
280  }
281 
287  void add(ConnectionID &cID, State &st) {
288  Hasher hasher;
289  uint64_t bucket = hasher(cID) & bucketMask;
290 
291  std::lock_guard<SpinLockCLSize> lock(*(newStatesLock[bucket]));
292  auto insPair = std::pair<ConnectionID, State>(cID, st);
293  newStates[bucket]->insert(insPair);
294  };
295 
306  bool findAndErase(ConnectionID &cID, State *st) {
307  Hasher hasher;
308  uint64_t bucket = hasher(cID) & bucketMask;
309 
310  std::lock_guard<SpinLockCLSize> lock(*(newStatesLock[bucket]));
311  auto ret = newStates[bucket]->find(cID);
312 
313  if (ret != newStates[bucket]->end()) {
314  st->set(ret->second);
315  newStates[bucket]->erase(ret);
316  return true;
317  } else {
318  return false;
319  }
320  };
321 
322  // TODO I should really do this sooner than later
323 #if 0
324  // If this is needed later, I will add it later
329  void erase(ConnectionID &cID);
330 #endif
331  };
332 
333 private:
334  /*
335  * XXX -------------------------------------------- XXX
336  * Private attributes
337  * XXX -------------------------------------------- XXX
338  */
339 
340  // This is the heart of the state tracking
341  // stateTable holds the link between connections and states
342  std::unordered_map<ConnectionID, State, Hasher> stateTable;
343 
344  // This table specifies which function should be called for a packet
345  // belonging to a connection in a specific state
// Indexed by StateID; grown on demand by registerFunction().
346  std::vector<stateFun> functions;
347 
348  // TODO: Since the identifier should not track anything, maybe we don't
349  // even need a member -> all static functions
350  Identifier identifier;
351 
352  // The state a newly received connection starts in
353  // This is only used if listenToConnections is true
354  StateID startStateID;
355 
356  // This function gets called on newly received connections; its return
// value becomes the new State's stateData.
357  // This is only used if listenToConnections is true
358  std::function<void *(ConnectionID)> startStateFun;
359 
360  // If a connection reaches this state, it gets destroyed
361  StateID endStateID;
362 
363  // Callback to acquire new packets
// NOTE(review): only stored by registerGetPktCB(); nothing in this listing
// reads it (FunIface::getPkt() is not implemented).
364  std::function<Packet *()> getPktCB;
365 
366  // Basically: Server mode or client mode
// true: unknown connections are accepted and start in startStateID.
367  bool listenToConnections;
368 
369  /*
370  * XXX -------------------------------------------- XXX
371  * Timeout handling
372  * XXX -------------------------------------------- XXX
373  */
374 
375  // This represents the timeout tracking
// One scheduled timeout in timeoutsQ: when it expires, and the ID of the
// timeoutFunctions entry it refers to.
struct Timeout {
  // Point in time at which the timeout expires.
  std::chrono::time_point<std::chrono::steady_clock> time;

  // Uniquely identifies this timeout; key into timeoutFunctions.
  uint32_t timeoutID;

  // Ordering for std::priority_queue. The queue keeps the element that
  // compares *greatest* on top, so "a > b" places the earliest expiry on top.
  // runPktBatch() relies on this: it stops at the first timeout that lies in
  // the future. ("a < b" would put the latest timeout on top and hold back
  // already-expired ones.)
  class Compare {
  public:
    bool operator()(const struct Timeout &a, const struct Timeout &b) const {
      return a.time > b.time;
    }
  };
};
389 
390  // This increments for every timeout -> provides a unique ID
// Starts at 0 (constructor); wraps around after 2^32 timeouts.
391  uint32_t curTimeoutID;
392 
393  // This contains all the timeouts
// Cancelled timeouts are not removed from the queue; runPktBatch() skips
// entries whose ID is no longer present in timeoutFunctions.
394  std::priority_queue<struct Timeout, std::vector<struct Timeout>,
395  typename Timeout::Compare>
396  timeoutsQ;
397 
398  // This is used only to check if a timeout is valid, and given that, to
399  // execute it
400  struct TimeoutData {
401  // This is used to identify, which connection the timeout belongs to
402  // indexes the stateTable
403  ConnectionID id;
404 
405  // This function is executed, if the timeout ticks out
406  timeoutFun fun;
407 
408  // Trivial constructor
409  TimeoutData(ConnectionID id, timeoutFun fun) : id(id), fun(fun){};
410  };
411 
412  // If a timeoutID maps to a TimeoutData, this timeout is valid;
413  // if find() returns end(), the timeout was cancelled and is ignored.
414  // The unique_ptr should make sure that the map is not blown up by large ConnectionIDs
415  // XXX The indirection assumes that timeouts don't happen frequently; otherwise
416  // it may be better to store them directly in the map? (not sure)
417  std::unordered_map<uint32_t, std::unique_ptr<struct TimeoutData>> timeoutFunctions;
418 
419  /*
420  * XXX -------------------------------------------- XXX
421  * Connection sharing
422  * XXX -------------------------------------------- XXX
423  */
424 
425  // These two members allow a connection to be opened on one core
426  // and its subsequent packets to be received on another
427  static ConnectionPool connPoolStatic;
// Points to connPoolStatic unless setConnectionPool() replaced it.
428  ConnectionPool *connPool;
429 
430  /*
431  * XXX -------------------------------------------- XXX
432  * Statistics
433  * XXX -------------------------------------------- XXX
434  */
435 
// stat_statesAdded counts states created or pulled from the pool by
// findState(). stat_statesClosed is only incremented in runPkt(); removals
// done by ~FunIface or by the timeout loop in runPktBatch() are not counted.
436  uint64_t stat_statesAdded = 0;
437  uint64_t stat_statesClosed = 0;
438 
439  /*
440  * XXX -------------------------------------------- XXX
441  * Private helper methods
442  * XXX -------------------------------------------- XXX
443  */
444 
445  auto findState(ConnectionID id) {
446  findStateLoop:
447  auto stateIt = stateTable.find(id);
448  if (stateIt == stateTable.end()) {
449 
450  // Try to find state in the connection pool
451  {
452  State st;
453  if (connPool->findAndErase(id, &st)) {
454  D(std::cout << "StateMachine::findState() found state in connPool"
455  << std::endl;)
456  stateTable.insert({id, st});
457 
458  stat_statesAdded++;
459 
460  goto findStateLoop;
461  }
462  }
463 
464  // Maybe accept the new connection
465  if (listenToConnections) {
466  // Add new state
467  D(std::cout << "Adding new state" << std::endl;)
468  D(std::cout << "ConnectionID: " << static_cast<std::string>(id) << std::endl;)
469 
470  // Create startState data object
471  void *stateData = nullptr;
472  if (startStateFun) {
473  stateData = startStateFun(id);
474  }
475 
476  State s(startStateID, stateData);
477  stateTable.insert({id, s});
478 
479  stat_statesAdded++;
480 
481  goto findStateLoop;
482  }
483 
484  } else {
485  D(std::cout << "State found" << std::endl;)
486  }
487  return stateIt;
488  };
489 
490  void runPkt(BufArray<Packet> &pktsIn, unsigned int cur) {
491  D(std::cout << std::endl << "StateMachine::runPkt() called" << std::endl;)
492 
493  try {
494  // Retrieve the current packet
495  Packet *pktIn = pktsIn[cur];
496 
497  // Try to identify the inbound packet
498  ConnectionID identity = identifier.identify(pktIn);
499 
500  // Find a state/connection associated with this packet
501  auto stateIt = findState(identity);
502 
503  if (stateIt == stateTable.end()) {
504  // We don't want this packet
505  D(std::cout << "StateMachine::runPkt() discarding packet" << std::endl;)
506  D(std::cout << "ident of packet: " << static_cast<std::string>(identity)
507  << std::endl;)
508  return;
509  }
510 
511  // Invalidate any previous timeouts
512  if (stateIt->second.timeoutID != timeoutIDInvalid) {
513  this->timeoutFunctions.erase(stateIt->second.timeoutID);
514  stateIt->second.timeoutID = timeoutIDInvalid;
515  }
516 
517  // Try to retrieve an appropriate function
518  /*
519  auto sfIt = functions.find(stateIt->second.state);
520  if (sfIt == functions.end()) {
521  D(std::cout << "StateMachine::runPkt() Didn't find a function for this state"
522  << std::endl;)
523  throw std::runtime_error("StateMachine::runPkt() No such function found");
524  }
525  */
526  assert((functions.size() - 1) >= stateIt->second.state);
527  auto fun = functions[stateIt->second.state];
528  assert(fun != nullptr);
529 
530  // Create the custom function interface
531  FunIface funIface(this, cur, pktsIn, identity, stateIt->second);
532 
533  // Run the function
534  D(std::cout << "StateMachine::runPkt() Running Function" << std::endl;)
535  D(std::cout << "StateMachine::runPkt() identity: "
536  << static_cast<std::string>(identity) << std::endl;)
537  D(std::cout << "StateMachine::runPkt() hexdump of packet: " << std::endl;)
538  D(hexdump(pktIn->getData(), pktIn->getDataLen());)
539  //(sfIt->second)(stateIt->second, pktIn, funIface);
540  fun(stateIt->second, pktIn, funIface);
541 
542  // Check if the endstate is reached
543  if (stateIt->second.state == endStateID) {
544  D(std::cout
545  << "StateMachine::runPkt() Reached endStateID - deleting connection"
546  << std::endl;)
547  removeState(identity);
548 
549  stat_statesClosed++;
550  }
551 
552  // At this point, the funIface is destroyed, and it is checked, if
553  // any transitionNow calls were made.
554  // XXX Therefore: DO NOT WRITE ANY CODE BELOW THIS COMMENT
555  // (or at least give it a hard thought)
556 
557  } catch (PacketNotIdentified *e) {
558  D(std::cout << "StateMachine::runPkt() Packet could not be identified"
559  << std::endl;);
560  pktsIn.markDropPkt(cur);
561  }
562  }
563 
564 public:
565  /*
566  * XXX -------------------------------------------- XXX
567  * Public interface
568  * XXX -------------------------------------------- XXX
569  */
570 
// Initializer list of the StateMachine() constructor; the constructor's header
// (source line 571) is missing from this listing. Defaults: start state 0, no
// end state (StateIDInvalid), client mode (listenToConnections = false),
// timeout IDs counted from 0, and the shared static connection pool.
572  : startStateID(0), endStateID(StateIDInvalid), listenToConnections(false),
573  curTimeoutID(0), connPool(&connPoolStatic){};
574 
// Body of a statistics printer whose header (source line 575) is missing from
// this listing: writes the stateTable size and the added/closed counters to
// std::cout.
576  std::cout << "StateMachine stats:" << std::endl;
577  std::cout << "stateTable.size() = " << stateTable.size() << std::endl;
578  std::cout << "statesAdded = " << stat_statesAdded << std::endl;
579  std::cout << "statesClosed = " << stat_statesClosed << std::endl;
580  }
581 
587  size_t getStateTableSize() { return stateTable.size(); };
588 
597  void registerFunction(StateID id, stateFun function) {
598  // functions.insert({id, function});
599  if ((static_cast<int>(functions.size()) - 1) < static_cast<int>(id)) {
600  functions.resize(id + 1);
601  }
602  functions[id] = function;
603  }
604 
610  void registerEndStateID(StateID endStateID) { this->endStateID = endStateID; }
611 
// Remainder of registerStartStateID(StateID startStateID,
// std::function<void *(ConnectionID)> startStateFun); the first header line
// (source line 621) is missing from this listing. Switches the machine into
// listening mode: unknown connections start in startStateID, with stateData
// produced by startStateFun.
622  StateID startStateID, std::function<void *(ConnectionID)> startStateFun) {
623  this->startStateID = startStateID;
624  this->startStateFun = startStateFun;
625  listenToConnections = true;
626  }
627 
632  void registerGetPktCB(std::function<Packet *()> fun) { getPktCB = fun; }
633 
634  void setConnectionPool(ConnectionPool *cp) { connPool = cp; }
635 
642  void removeState(ConnectionID id) { stateTable.erase(id); }
643 
652  void addState(ConnectionID id, State st, BufArray<Packet> &pktsIn) {
653 
654  /*
655  auto sfIt = functions.find(st.state);
656  if (sfIt == functions.end()) {
657  throw std::runtime_error("StateMachine::addState() No such function found");
658  }
659  */
660 
661  assert((functions.size() - 1) >= st.state);
662  auto fun = functions[st.state];
663  assert(fun != nullptr);
664 
665  FunIface funIface(this, 0, pktsIn, id, st);
666 
667  D(std::cout << "StateMachine::addState() Running Function" << std::endl;)
668  //(sfIt->second)(st, pktsIn[0], funIface);
669  fun(st, pktsIn[0], funIface);
670 
671  if (st.state == endStateID) {
672  D(std::cout << "StateMachine::addState() Reached endStateID - deleting connection"
673  << std::endl;)
674  return;
675  }
676 
677  D(std::cout << "StateMachine::addState() adding connection to newStates"
678  << std::endl;)
679  D(std::cout << "StateMachine::addState() identity: " << static_cast<std::string>(id)
680  << std::endl;)
681  connPool->add(id, st);
682  }
683 
// Body of runPktBatch(BufArray<Packet> &pktsIn); the function header (source
// line 692) is missing from this listing. First fires every expired timeout,
// then hands each packet of the batch to runPkt().
693  uint32_t inCount = pktsIn.getTotalCount();
694 
695  D(std::cout << std::endl
696  << "StateMachine::runPktBatch() running incoming batch now, #Pkts: "
697  << inCount << std::endl;)
698 
699  // This loop handles the timeouts
700  // It breaks, if there are no usable timeouts anymore
701  // It will (usually) not run until timeoutsQ is empty
// NOTE(review): stopping at the first future timeout is only correct if
// timeoutsQ.top() is the *earliest* timeout, i.e. Timeout::Compare must
// return `a.time > b.time` (std::priority_queue keeps the greatest element on
// top).
702  while (!timeoutsQ.empty()) {
703  // Get the next timeout
704  auto timeoutElem = timeoutsQ.top();
705 
706  // If the current timeout is in the future -> break
707  if (timeoutElem.time > std::chrono::steady_clock::now()) {
708  break;
709  }
710 
711  // Extract some info from the timeout
712  uint32_t timeoutID = timeoutElem.timeoutID;
713  auto timeoutDataIt = timeoutFunctions.find(timeoutID);
714 
715  // Check if the timeout is valid
716  if (timeoutDataIt == timeoutFunctions.end()) {
717  timeoutsQ.pop();
718  continue;
719  }
720 
721  // Prepare function call
722  std::unique_ptr<struct TimeoutData> timeoutData =
723  std::move(timeoutDataIt->second);
// NOTE(review): findState() may also pull the connection out of connPool or,
// in listening mode, create a brand-new state instead of failing the assert.
724  auto stateIt = findState(timeoutData->id);
725  assert(stateIt != stateTable.end());
// NOTE(review): pktIdx is UINT32_MAX here; freePkt() or transitionNow() from a
// timeout function makes ~FunIface index pktsIn with that value.
726  FunIface funIface(this, std::numeric_limits<uint32_t>::max(), pktsIn,
727  timeoutData->id, stateIt->second);
728 
729  // Clear the timeoutID from the state
730  stateIt->second.timeoutID = timeoutIDInvalid;
731 
732  // Call function
733  timeoutData->fun(stateIt->second, funIface);
734 
735  // Check if we reached the end state
// NOTE(review): after removeState() the State referenced by funIface is
// erased, but funIface is only destroyed at the end of this iteration.
736  if (stateIt->second.state == endStateID) {
737  D(std::cout << "Reached endStateID - deleting connection" << std::endl;)
738  removeState(timeoutData->id);
739  }
740 
741  // Clean up
// NOTE(review): if fun called setTimeout(), timeoutsQ.push() may have put the
// new timeout on top (so pop() would discard it instead of this one), and
// timeoutFunctions.emplace() may have rehashed, invalidating timeoutDataIt.
// Popping before the call and erasing by timeoutID would avoid both.
742  timeoutsQ.pop();
743  timeoutFunctions.erase(timeoutDataIt);
744  }
745 
746  // Run all the usual incoming packets
747  for (uint32_t i = 0; i < inCount; i++) {
748  runPkt(pktsIn, i);
749  }
750  }
751 };
752 
753 // Define static members of the state machine
754 
755 // Don't try to understand the template stuff, it works...
// NOTE(review): source lines 757-758 -- presumably the out-of-class definition
// of the static member StateMachine<Identifier, Packet>::connPoolStatic that
// this template header introduces -- are missing from this listing.
756 template <class Identifier, class Packet>
759 
760 /*
761 template <class Identifier, class Packet>
762 const uint64_t StateMachine<Identifier, Packet>::ConnectionPool::bucketMask;
763 
764 template <class Identifier, class Packet>
765 const unsigned int StateMachine<Identifier, Packet>::ConnectionPool::numBuckets;
766 */
767 
768 #endif /* STATE_MACHINE_HPP */
void runPktBatch(BufArray< Packet > &pktsIn)
Run a batch of packets.
State machine framework.
void removeState(ConnectionID id)
Remove a connection.
void set(const State &s)
Main interface for the needs of a state function.
Packet * getPkt()
Get an additional packet buffer.
void markDropPkt(uint32_t pktIdx)
Mark one packet as drop.
Definition: bufArray.hpp:87
void transitionNow(StateID newState)
Immediately transition to another state.
std::function< void(State &, mbuf *, FunIface &)> stateFun
This is the signature any state function needs to expose.
State(StateID state, void *stateData)
void hexdump(const void *data, int dataLen)
Dump hex data.
Definition: hexdump.cpp:12
uint32_t getTotalCount() const
Get the number of all packets in the BufArray.
Definition: bufArray.hpp:211
#define D(x)
Definition: common.hpp:10
void registerGetPktCB(std::function< Packet *()> fun)
Register a callback used to obtain new packet buffers.
bool findAndErase(ConnectionID &cID, State *st)
Try to find a state for a given connection ID.
void freePkt()
Free the packet after the batch is processed instead of sending it.
bool operator()(struct Timeout a, struct Timeout b)
void addState(ConnectionID id, State st, BufArray< Packet > &pktsIn)
Open an outgoing connection.
void registerStartStateID(StateID startStateID, std::function< void *(ConnectionID)> startStateFun)
This method describes how to proceed with incoming connections.
void add(ConnectionID &cID, State &st)
Add connection and state to the connection pool.
static constexpr auto StateIDInvalid
Represents an invalid StateID.
void setConnectionPool(ConnectionPool *cp)
uint16_t StateID
Definition: common.hpp:19
void registerFunction(StateID id, stateFun function)
Register a function for a given state.
std::function< void(State &, FunIface &)> timeoutFun
This is the signature any timeout function needs to expose.
size_t getStateTableSize()
Get the number of tracked connections. This is probably only useful for statistics.
Represents one connection.
void setTimeout(std::chrono::milliseconds timeout, timeoutFun fun)
Set a timeout, after which a transition will happen.
Wrapper around MoonGen bufarrays.
Definition: bufArray.hpp:42
State(const State &s)
void registerEndStateID(StateID endStateID)
This registers an end state. If a connection reaches this ID, it will be destroyed.
void transition(StateID newState)
Transition to another state.