Reference counting tutorial

If you write multi-threaded programs in C or C++, you really should get a good handle on reference counting. The Linux kernel coding style puts it this way [btw, if you haven't read that, do it now. Go on, I'll wait.]:

Remember: if another thread can find your data structure, and you don't have a reference count on it, you almost certainly have a bug.

The solution I will present is, if not a design pattern (sounds too buzzy), at least good practice. After discovering it, I changed a lot of my code to use it consistently, and found that it makes the code simpler and less prone to synchronization bugs, memory leaks and slugs. Get it right, and you won't miss the GC too much.

I said C or C++ because languages with a built-in garbage collector solve this transparently for you. Not for free, though: you pay in (many) extra CPU cycles.

Sample Problem

Let's take something simple:

In a C/C++ application, you have a global list [| collection | hashtable | tree] of objects that can be accessed by multiple threads. Each object has a unique identifier. Each thread can add a new object to the list, or get an object to modify or delete it.

Sounds basic, doesn't it? Well, you can really shoot yourself in the foot with it. Special care must be taken when deleting objects, because other threads might be working with that object or iterating through the global list at that moment.

Disclaimers

The method is Not Invented Here (tm). I saw it in the Linux kernel code, to name one place, and it is popular among experienced programmers. I haven't seen many tutorials about it, though. BTW, this article might be useful for you if you're trying to understand how reference counting is used in the Linux kernel.

The code snippets are in C with the pthread library, but better consider them as pseudo-code. They're meant as starting points, not as ready for copy'n'paste. I have never tried them, so I would be surprised if they even compile cleanly.

Solution

Over time, I found this minimal API to be convenient:

obj_t* obj_lookup(int id);
obj_t* obj_lookup_or_create(int id, int *isnew);
void obj_kill(obj_t *obj);

The obj_lookup() function needs no introduction: it simply searches for an object by id in the global list. If none is found, it returns NULL.

There is no explicit create function because we cannot have two objects with the same id, and the API reflects this. obj_lookup_or_create() searches for an object with the specified id and, if none is found, creates one and returns it. It returns NULL only if creating the object fails. The isnew output parameter tells the caller whether the object was created by this call.

The obj_kill() function removes the object from the global list and releases it.

Since I mentioned multi-threading, we need to think about locking. Having a single Big Lock (BFL for short ;-) ) that protects both the global list and all the objects is simple. It also doesn't leave room for many bugs, but all operations are serialized and performance drops. I don't consider this an option. [In fact, there's more than a performance issue: I think it's a pity that many GUI applications choose this method, and we get unresponsive applications when one thread blocks for a long time with the BFL taken.]

It's better to lock each object individually and have another lock for accessing the global list. Operations on different objects can then run in parallel, so we take advantage of all those cores in the new machines.

The buggy way

I would like to start with a naive implementation, to better identify the issue:

static list_t *_list;
static pthread_mutex_t _mutex = PTHREAD_MUTEX_INITIALIZER; /* protects _list */


typedef struct obj {
    int id;
    void *value;

    pthread_mutex_t mutex; /* protects this object's fields */
} obj_t;


obj_t* lookup_or_create(int id, int *isnew)
{
    obj_t *obj;

    pthread_mutex_lock(&_mutex);
    list_for_each(obj, _list) {
        if (obj->id == id) {
            pthread_mutex_unlock(&_mutex);
            *isnew = 0;
            return obj;
        }
    }

    /* not found, create */
    obj = obj_new(id);
    if (obj)
        list_add(_list, obj);
    pthread_mutex_unlock(&_mutex);

    *isnew = (obj != NULL);
    return obj;
}

void obj_kill(obj_t *obj)
{
    pthread_mutex_lock(&_mutex);
    list_del(_list, obj);
    pthread_mutex_unlock(&_mutex);
}

obj_lookup() is too similar to obj_lookup_or_create() to be worth showing. A possible usage (again, naive) would be:

int process_obj_with_id(int id)
{
    obj_t *obj;
    int isnew;

    obj = lookup_or_create(id, &isnew);
    if (!obj)
        return -1;

    pthread_mutex_lock(&obj->mutex);
    process_object(obj);
    if (should_be_dead(obj))
        obj_kill(obj);
    pthread_mutex_unlock(&obj->mutex);

    if (obj_is_dead(obj))
        free(obj);

    return 0;
}

This function can be called by multiple threads concurrently; that was the whole point. The process_object() function does the actual job (whatever that is). The should_be_dead() check is there only to suggest that the user can decide to delete the object at any point.

Can you spot the bug in the above code?

It's easy. Consider the following situation: two threads call process_obj_with_id() at the same time with the same id. An object with that id already exists. Both of them execute the lookup and get pointers to the object. One of the threads takes the lock, the other waits. So far, so good. But now the first thread decides to delete the object and frees its memory. Oops. The second thread now has an invalid pointer to work with and will probably segfault.

You can try to work around it by using properties specific to the application or by locking tricks, but I wouldn't recommend it. All you'll get is buggy code. If you recognize this pattern, you should use

The good way

The good way is to use reference counting. Whenever you get a pointer to an object, increment its reference count. Whenever the pointer goes out of scope or is removed, decrement it. When the reference count of an object gets to zero, free it. Reference counting can be efficiently implemented with atomic operations:

typedef struct obj {
    int id;
    void *value;

    pthread_mutex_t mutex;
    atomic_t ref;
    bool_t killed;
} obj_t;


static inline void obj_refinc(obj_t *obj)
{
    assert(atomic_read(&obj->ref) > 0);
    atomic_inc(&obj->ref);
}

static inline int obj_refdec(obj_t *obj)
{
    if (atomic_dec_and_test(&obj->ref)) {
        assert(obj->killed); /* only a killed object may drop to zero */
        free(obj);
        return TRUE;
    }
    return FALSE;
}

Things to note so far:

  • The assert in obj_refinc() (line 13) is a nice bug trap: you can't increment the reference count of an object if you don't already have a pointer to it. Thus, the reference count must be greater than zero already.
  • The 'killed' boolean in the obj structure is not always needed. It does a good job, however, of guarding against double deletion (see the obj_kill() function below), so I usually add it.
  • obj_refdec() returns TRUE if the object was freed by this operation. This is also not strictly needed, but it's useful for bug trapping: if obj_refdec() returns TRUE, you know that accessing the object from now on will cause problems.
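
The atomic_t, atomic_inc() and atomic_dec_and_test() names follow the Linux kernel API, which user-space programs don't get. As a rough sketch, assuming a C11 compiler with <stdatomic.h>, the same two helpers might look like this in user space. The obj_new() constructor is hypothetical (its body is never shown in this article), and plain bool and atomic_int stand in for bool_t and atomic_t:

```c
#include <assert.h>
#include <pthread.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stdlib.h>

typedef struct obj {
    int id;
    void *value;

    pthread_mutex_t mutex;
    atomic_int ref;   /* assumption: stands in for the kernel's atomic_t */
    bool killed;
} obj_t;

/* Hypothetical constructor: the new object starts with one reference,
 * owned by whoever called obj_new(). */
static obj_t *obj_new(int id)
{
    obj_t *obj = calloc(1, sizeof(*obj));
    if (!obj)
        return NULL;
    obj->id = id;
    pthread_mutex_init(&obj->mutex, NULL);
    atomic_init(&obj->ref, 1);
    return obj;
}

static inline void obj_refinc(obj_t *obj)
{
    /* fetch_add returns the previous value, which must already be > 0 */
    int old = atomic_fetch_add(&obj->ref, 1);
    assert(old > 0);
    (void)old;
}

/* Returns true if this call dropped the last reference and freed obj. */
static inline bool obj_refdec(obj_t *obj)
{
    /* fetch_sub returns the previous value: 1 means we just hit zero */
    if (atomic_fetch_sub(&obj->ref, 1) == 1) {
        pthread_mutex_destroy(&obj->mutex);
        free(obj);
        return true;
    }
    return false;
}
```

Checking the value returned by the fetch operation, rather than reading the counter in a separate step, keeps the test and the update in one atomic operation.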

The rest of the interesting functions can be implemented like this:

obj_t* lookup_or_create(int id, int *isnew)
{
    obj_t *obj;

    pthread_mutex_lock(&_mutex);
    list_for_each(obj, _list) {
        if (!obj->killed && obj->id == id) {
            obj_refinc(obj);
            pthread_mutex_unlock(&_mutex);
            *isnew = FALSE;
            return obj;
        }
    }

    /* not found, create */
    obj = obj_new(id); /* sets reference count to 1 */
    if (obj) {
        list_add(_list, obj);
        obj_refinc(obj); /* the list's reference */
    }
    pthread_mutex_unlock(&_mutex);
    *isnew = (obj != NULL);
    return obj;
}

void obj_kill(obj_t *obj)
{
    if (!obj->killed) {
        /* caller holds obj->mutex; lookups read 'killed' under _mutex */
        pthread_mutex_lock(&_mutex);
        obj->killed = TRUE;
        list_del(_list, obj);
        pthread_mutex_unlock(&_mutex);

        if (obj_refdec(obj)) {
            assert(0); /* BUG: the ref count got to zero too soon */
        }
    }
}

int process_obj_with_id(int id)
{
    obj_t *obj;
    int isnew;

    obj = lookup_or_create(id, &isnew);
    if (!obj)
        return -1;

    pthread_mutex_lock(&obj->mutex);
    if (!obj->killed) {
        process_object(obj);
        if (should_be_dead(obj))
            obj_kill(obj);
    }
    pthread_mutex_unlock(&obj->mutex);
    obj_refdec(obj);
    return 0;
}

Note that:
* When creating an object (line 16), the reference count is set to 1. This is to reflect the initial pointer, as returned by malloc, that we have on the object.
* When the object is added to the global list, its reference count is incremented (line 19).
* Every time a thread gets a pointer to the object from the global list, it increments the object's reference count (line 8), and every time the pointer goes out of scope, the reference count must be decremented (line 57).
* obj_kill() checks the return code of obj_refdec() (line 35). It is impossible to have a reference count of zero after this call, because there is at least one more pointer to the object (the *obj from process_obj_with_id, in this case).
* Since we hold our own reference after the call to lookup_or_create() (line 46), the reference count of obj stays above zero and the object will not be freed by another thread until we call obj_refdec().
* The 'killed' flag, however, needs to be checked after getting the lock (line 51) because another thread might have set it before we got the lock.
* After we are done with the object (usually when the pointer goes out of scope) we need to call obj_refdec() (line 57).
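
To see the pieces fit together in one compilable file, here is a minimal sketch of the three API functions. It rests on several assumptions that are not from the article: C11 atomics replace atomic_t, a hand-rolled singly linked list (the hypothetical `next` field and `head` pointer) replaces list_t and its helpers, and the object carries no payload. Because this version sets 'killed' and unlinks the object under the same list lock, a lookup can never find a killed object, so the lookup skips the 'killed' check.

```c
#include <assert.h>
#include <pthread.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stdlib.h>

typedef struct obj {
    int id;
    pthread_mutex_t mutex;   /* held by whoever works on the object */
    atomic_int ref;
    bool killed;
    struct obj *next;        /* hypothetical list link, protected by list_mutex */
} obj_t;

static obj_t *head;          /* the global list */
static pthread_mutex_t list_mutex = PTHREAD_MUTEX_INITIALIZER;

/* Returns true if this call dropped the last reference and freed obj. */
bool obj_refdec(obj_t *obj)
{
    if (atomic_fetch_sub(&obj->ref, 1) == 1) {
        assert(obj->killed);             /* only a killed object may reach zero */
        pthread_mutex_destroy(&obj->mutex);
        free(obj);
        return true;
    }
    return false;
}

/* Caller must hold list_mutex. Returns a referenced object or NULL. */
static obj_t *find_locked(int id)
{
    for (obj_t *o = head; o; o = o->next) {
        if (o->id == id) {
            atomic_fetch_add(&o->ref, 1);   /* the caller's reference */
            return o;
        }
    }
    return NULL;
}

obj_t *obj_lookup(int id)
{
    pthread_mutex_lock(&list_mutex);
    obj_t *obj = find_locked(id);
    pthread_mutex_unlock(&list_mutex);
    return obj;
}

obj_t *obj_lookup_or_create(int id, int *isnew)
{
    pthread_mutex_lock(&list_mutex);
    obj_t *obj = find_locked(id);
    *isnew = 0;
    if (!obj && (obj = calloc(1, sizeof(*obj))) != NULL) {
        obj->id = id;
        pthread_mutex_init(&obj->mutex, NULL);
        atomic_init(&obj->ref, 2);   /* one for the list, one for the caller */
        obj->next = head;
        head = obj;
        *isnew = 1;
    }
    pthread_mutex_unlock(&list_mutex);
    return obj;
}

/* Caller must hold obj->mutex and its own reference. */
void obj_kill(obj_t *obj)
{
    if (obj->killed)
        return;
    pthread_mutex_lock(&list_mutex);
    obj->killed = true;
    for (obj_t **pp = &head; *pp; pp = &(*pp)->next) {
        if (*pp == obj) {
            *pp = obj->next;         /* unlink from the global list */
            break;
        }
    }
    pthread_mutex_unlock(&list_mutex);
    bool freed = obj_refdec(obj);    /* drop the list's reference */
    assert(!freed);                  /* the caller still holds one */
    (void)freed;
}
```

A caller follows the same shape as process_obj_with_id(): look the object up, take obj->mutex, check 'killed', do the work, unlock, and finish with obj_refdec().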
