5/16/2015

05-16-15 - Threading Primitive - monitored semaphore

A monitored semaphore allows two-sided waiting :

The consumer side decs the semaphore, and waits on the count being positive.

The producer side incs the semaphore, and can wait on the count being a certain negative value (some number of waiting consumers).

Monitored semaphore solves a specific common problem :

In a worker thread system, you may need to wait on all work being done. This is hard to do in a race-free way using normal primitives. Typical ad-hoc solutions may miss work that is pushed during the wait-for-all-done phase. Requiring that nobody pushes work during that phase is hard to enforce, ugly, and a source of bugs. (It's particularly bad when work items may spawn new work items.)

I've heard of many ad-hoc hacky ways of dealing with this. There's no need to muck around with that, because there's a simple and efficient way to just get it right.

The monitored semaphore also provides a race-free way to snapshot the state of the work system - how many work items are available, how many workers are sleeping. This allows you to wait on the joint condition - all workers are sleeping AND there is no work available. Any check of those two using separate primitives is likely a race.
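
To make the "snapshot" idea concrete, here is a minimal sketch of how one 32-bit word can carry both the semaphore count and the number of waiters being watched for. The layout follows the FSM_* constants in the class later in this post (count in the top 24 bits, wait-for value in the low 8 bits). The names sem_snapshot, pack_state, unpack_state and all_idle are hypothetical helpers made up for illustration, not part of the original class.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for FSM_COUNT_SHIFT and FSM_WAIT_FOR_MASK.
constexpr int     COUNT_SHIFT   = 8;
constexpr int32_t WAIT_FOR_MASK = 0xFF;

// One consistent view of the work system, decoded from a single atomic load.
struct sem_snapshot
{
    int32_t count;    // > 0 : work items available ; < 0 : (-count) workers asleep
    int32_t wait_for; // waiter count the flushing thread asked to be woken at (0 = none)
};

// Pack count and wait_for into one word.
// The shift is done on unsigned to avoid left-shifting a negative value.
inline int32_t pack_state(int32_t count, int32_t wait_for)
{
    uint32_t bits = (static_cast<uint32_t>(count) << COUNT_SHIFT)
                  | static_cast<uint32_t>(wait_for & WAIT_FOR_MASK);
    return static_cast<int32_t>(bits);
}

// Assumes >> on a negative int32_t is an arithmetic shift
// (guaranteed since C++20, true on mainstream compilers before that).
inline sem_snapshot unpack_state(int32_t state)
{
    return sem_snapshot{ state >> COUNT_SHIFT, state & WAIT_FOR_MASK };
}

// The joint condition : every worker is asleep AND no work is available.
// A count of exactly -num_workers says both things at once.
inline bool all_idle(sem_snapshot s, int32_t num_workers)
{
    return s.count == -num_workers;
}
```

Because both fields live in one word, a single atomic load or compare-exchange sees them together; there is no window in which the count and the waiter information can disagree.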

The implementation is similar to the fastsemaphore I posted before.

"fastsemaphore" wraps some kind of underlying semaphore which actually provides the OS waits. The underlying semaphore is only used when the count goes negative. When count is positive, pops are done with simple atomic ops to avoid OS calls. That is, we only do an OS call when there's a possibility it will put our thread to sleep or wake a thread.

"fastsemaphore_monitored" uses the same kind of atomic variable wrapping an underlying semaphore, but adds an eventcount for the waiter side to be triggered when enough workers are waiting. (see "who ordered event count?")

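The two building blocks described above can be sketched roughly as follows. This is a sketch under assumptions, not the code from the earlier posts: base_semaphore is a mutex/condition-variable stand-in for whatever OS semaphore you wrap, fast_semaphore_sketch shows only the "atomic count in front of a real semaphore" idea, and eventcount_sketch is a simple lock-based eventcount whose method names (prepare_wait, wait, notify_all) are taken from how m_waiters_ec is used in the class below. All three class names are hypothetical.

```cpp
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <mutex>

// Stand-in for the underlying OS semaphore (assumption : any blocking semaphore works).
class base_semaphore
{
    std::mutex m_mutex;
    std::condition_variable m_cv;
    int m_count = 0;
public:
    void post()
    {
        { std::lock_guard<std::mutex> lock(m_mutex); ++m_count; }
        m_cv.notify_one();
    }
    void wait()
    {
        std::unique_lock<std::mutex> lock(m_mutex);
        m_cv.wait(lock, [&] { return m_count > 0; });
        --m_count;
    }
};

// Atomic count in front of the base semaphore ; the base is only touched
// when a thread must actually sleep or be woken.
class fast_semaphore_sketch
{
    std::atomic<int> m_count;
    base_semaphore m_sem;
public:
    explicit fast_semaphore_sketch(int count = 0) : m_count(count) { }
    void post()
    {
        // previous count < 0 means some thread is (or is about to be) blocked
        if ( m_count.fetch_add(1, std::memory_order_acq_rel) < 0 )
            m_sem.post();
    }
    void wait()
    {
        // previous count <= 0 means nothing was available ; go to sleep
        if ( m_count.fetch_sub(1, std::memory_order_acq_rel) <= 0 )
            m_sem.wait();
    }
    int count() const { return m_count.load(std::memory_order_acquire); }
};

// Lock-based eventcount : wait(key) returns once notify_all has been
// called at least once since prepare_wait returned key.
class eventcount_sketch
{
    std::mutex m_mutex;
    std::condition_variable m_cv;
    unsigned m_count = 0;
public:
    unsigned prepare_wait()
    {
        std::lock_guard<std::mutex> lock(m_mutex);
        return m_count;
    }
    void wait(unsigned key)
    {
        std::unique_lock<std::mutex> lock(m_mutex);
        m_cv.wait(lock, [&] { return m_count != key; });
    }
    void notify_all()
    {
        { std::lock_guard<std::mutex> lock(m_mutex); ++m_count; }
        m_cv.notify_all();
    }
};
```

The eventcount pattern is : prepare_wait, re-check the condition you care about, then wait(key). If the notify happens anywhere between prepare_wait and wait, the key no longer matches and wait returns immediately, so the wakeup cannot be lost.
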
Usage is like this :


To push a work item :

push item on your queue (MPMC FIFO or whatever)
fastsemaphore_monitored.post();

To pop a work item :

fastsemaphore_monitored.wait();
pop item from queue

To flush all work :

fastsemaphore_monitored.wait_for_waiters(num_worker_threads);

NOTE : in my implementation, post & wait can be called from any thread, but wait_for_waiters must be called from only one thread. This assumes you either have a "main thread" that does that wait, or that you wrap that call with a mutex.
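
The push / pop / flush protocol above can be tried out with a small worker pool. The sketch below does NOT use the lock-free class from this post; monitored_semaphore_sketch is a mutex-based stand-in that only models the same three calls, so the example is self-contained. work_system, run_flush_demo and the "item 0 means quit" sentinel are likewise hypothetical, made up for the demo. Work items spawn further work items, which is the case that breaks most ad-hoc flushes.

```cpp
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <thread>
#include <vector>

// Lock-based stand-in with the same interface as fastsemaphore_monitored.
// Assumption : it models the semantics only, not the lock-free implementation.
class monitored_semaphore_sketch
{
    std::mutex m_mutex;
    std::condition_variable m_items_cv;   // consumers sleep here
    std::condition_variable m_waiters_cv; // the flushing thread sleeps here
    int m_avail = 0;                      // posted but not yet consumed
    int m_waiters = 0;                    // consumers currently blocked in wait()
public:
    void post()
    {
        std::lock_guard<std::mutex> lock(m_mutex);
        ++m_avail;
        m_items_cv.notify_one();
    }
    void wait()
    {
        std::unique_lock<std::mutex> lock(m_mutex);
        if ( m_avail == 0 )
        {
            ++m_waiters;
            m_waiters_cv.notify_all(); // waiter count changed ; let the flusher re-check
            m_items_cv.wait(lock, [&] { return m_avail > 0; });
            --m_waiters;
        }
        --m_avail;
    }
    // returns once num_waiters consumers are asleep AND nothing is available
    void wait_for_waiters(int num_waiters)
    {
        std::unique_lock<std::mutex> lock(m_mutex);
        m_waiters_cv.wait(lock, [&] { return m_waiters == num_waiters && m_avail == 0; });
    }
};

struct work_system
{
    monitored_semaphore_sketch sem;
    std::mutex queue_mutex;
    std::deque<int> queue;
    std::atomic<int> processed{0};

    void push(int item)
    {
        { std::lock_guard<std::mutex> lock(queue_mutex); queue.push_back(item); }
        sem.post(); // post AFTER the item is in the queue
    }
    int pop()
    {
        sem.wait(); // a successful wait guarantees the queue is non-empty
        std::lock_guard<std::mutex> lock(queue_mutex);
        int item = queue.front();
        queue.pop_front();
        return item;
    }
    void worker()
    {
        for(;;)
        {
            int item = pop();
            if ( item == 0 ) return;        // hypothetical "quit" sentinel
            if ( item > 1 ) push(item - 1); // work items may spawn new work items
            ++processed;
        }
    }
};

// Returns how many items had been processed at the moment the flush returned.
int run_flush_demo(int num_workers)
{
    work_system ws;
    std::vector<std::thread> threads;
    for(int i=0;i<num_workers;i++)
        threads.emplace_back([&ws] { ws.worker(); });

    ws.push(3); // spawns 2, which spawns 1 : three items total
    ws.push(3); // another three
    ws.sem.wait_for_waiters(num_workers); // flush all work
    int done = ws.processed.load();

    for(int i=0;i<num_workers;i++) ws.push(0); // tell each worker to quit
    for(auto & t : threads) t.join();
    return done;
}
```

Each worker finishes its item (including pushing any child item) before it calls wait again, so "all workers asleep and nothing available" implies every item, spawned ones included, has been processed. With the two seeds above that is six items regardless of the worker count.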

template <typename t_base_sem>
class fastsemaphore_monitored
{
    atomic<S32> m_state;
    eventcount m_waiters_ec;
    t_base_sem m_sem;

    enum { FSM_COUNT_SHIFT = 8 };
    enum { FSM_COUNT_MASK = 0xFFFFFF00UL };
    enum { FSM_COUNT_MAX = ((U32)FSM_COUNT_MASK>>FSM_COUNT_SHIFT) };
    enum { FSM_WAIT_FOR_SHIFT = 0 };
    enum { FSM_WAIT_FOR_MASK = 0xFF };
    enum { FSM_WAIT_FOR_MAX = (FSM_WAIT_FOR_MASK>>FSM_WAIT_FOR_SHIFT) };

public:
    fastsemaphore_monitored(S32 count = 0)
    :   m_state(count<<FSM_COUNT_SHIFT)
    {
        RL_ASSERT(count >= 0);
    }

    ~fastsemaphore_monitored()
    {
    }

public:

    inline S32 state_fetch_add_count(S32 inc)
    {
        S32 prev = m_state($).fetch_add(inc<<FSM_COUNT_SHIFT,mo_acq_rel);
        S32 count = ( prev >> FSM_COUNT_SHIFT );
        RR_ASSERT( count < 0 || ( (U32)count < (FSM_COUNT_MAX-2) ) );
        return count;
    }

    // warning : wait_for_waiters can only be called from one thread!
    void wait_for_waiters(S32 wait_for_count)
    {
        RL_ASSERT( wait_for_count > 0 && wait_for_count < FSM_WAIT_FOR_MAX );
        
        S32 state = m_state($).load(mo_acquire);
        
        for(;;)
        {
            S32 cur_count = state >> FSM_COUNT_SHIFT;

            if ( (-cur_count) == wait_for_count )
                break; // got it
        
            S32 new_state = (cur_count<<FSM_COUNT_SHIFT) | (wait_for_count << FSM_WAIT_FOR_SHIFT);
            
            S32 ec = m_waiters_ec.prepare_wait();
            
            // double check and signal what we're waiting for :
            if ( ! m_state($).compare_exchange_strong(state,new_state,mo_acq_rel) )
                continue; // retry ; state was reloaded
            
            m_waiters_ec.wait(ec);
            
            state = m_state($).load(mo_acquire);
        }
        
        // now turn off the mask :
        
        for(;;)
        {
            S32 new_state = state & FSM_COUNT_MASK;
            if ( state == new_state ) return;
        
            if ( m_state($).compare_exchange_strong(state,new_state,mo_acq_rel) )
                return; 
                
            // retry ; state was reloaded
        }
    }

    void post()
    {
        if ( state_fetch_add_count(1) < 0 )
        {
            m_sem.post();
        }
    }

    void wait_no_spin()
    {
        S32 prev_state = m_state($).fetch_add((-1)<<FSM_COUNT_SHIFT,mo_acq_rel);
        S32 prev_count = prev_state>>FSM_COUNT_SHIFT;
        if ( prev_count <= 0 )
        {
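            // number of waiters, including this thread :
            S32 waiters = (-prev_count) + 1;
            RR_ASSERT( waiters >= 1 );
            // wait_for is nonzero only while a wait_for_waiters call is pending :
            S32 wait_for = prev_state & FSM_WAIT_FOR_MASK;
            if ( waiters == wait_for )
            {
                RR_ASSERT( wait_for >= 1 );
                // requested number of waiters reached ; wake the wait_for_waiters thread
                m_waiters_ec.notify_all();
            }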
            
            m_sem.wait();
        }
    }
    
    void post(S32 n)
    {
        RR_ASSERT( n > 0 );
        for(S32 i=0;i<n;i++)
            post();
    }
       
    bool try_wait()
    {
        // see if we can dec count before preparing the wait
        S32 state = m_state($).load(mo_acquire);
        for(;;)
        {
            if ( state < (1<<FSM_COUNT_SHIFT) ) return false;
            // dec count and leave the rest the same :
            //S32 new_state = ((c-1)<<FSM_COUNT_SHIFT) | (state & FSM_WAIT_FOR_MASK);
            S32 new_state = state - (1<<FSM_COUNT_SHIFT);
            RR_ASSERT( (new_state>>FSM_COUNT_SHIFT) >= 0 );
            if ( m_state($).compare_exchange_strong(state,new_state,mo_acq_rel) )
                return true;
            // state was reloaded
            // loop
            // backoff here optional
        }
    }
     
       
    S32 try_wait_all()
    {
        // see if we can dec count before preparing the wait
        S32 state = m_state($).load(mo_acquire);
        for(;;)
        {
            S32 count = state >> FSM_COUNT_SHIFT;
            if ( count <= 0 ) return 0;
            // swap count to zero and leave the rest the same :
            S32 new_state = state & FSM_WAIT_FOR_MASK;
            if ( m_state($).compare_exchange_strong(state,new_state,mo_acq_rel) )
                return count;
            // state was reloaded
            // loop
            // backoff here optional
        }
    }
           
    void wait()
    {
        int spin_count = rrGetSpinCount();
        while(spin_count--)
        {
            if ( try_wait() ) 
                return;
        }
        
        wait_no_spin();
    }

};
