Stupid Lambda Tricks


Hi.  I’m Arjun Bijanki, the test lead for the compiler front-end and IntelliSense engine.  One afternoon a few months ago, I was sitting in my office in building 41 thinking about test passes, when an animated discussion between a couple of colleagues spilled into the hallway and grabbed my attention.  My recollection is that Boris Jabes, whom some of you might have seen deliver the “10 is the new 6” talk at PDC last week, was trying to convince the other colleague that you could write an automatic memoization function for C++0x lambdas.

I became intrigued.

For those who haven’t used C++0x lambdas before, the feature provides a way to define unnamed function objects.  VCBlogger STL wrote up a great post that describes lambdas and their syntax in more detail.
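Concretely, a lambda expression produces an object of a compiler-generated, unnamed class with an operator(); variables captured by copy become data members of that object. The sketch below (C++11) shows a hand-written class that behaves like one small lambda. The names `add_n` and `add_with_lambda` are hypothetical, and the compiler's real closure type has no name you can spell.

```cpp
#include <cassert>

// Hand-written equivalent of the closure type for [n](int value) { return value + n; }
class add_n {
public:
    explicit add_n(int n) : n_(n) {}                     // the capture, copied in at creation
    int operator()(int value) const { return value + n_; } // the lambda body
private:
    int n_;
};

// The same behavior written as a lambda; its type is unnamed, so auto deduces it.
inline int add_with_lambda(int n, int x) {
    auto add = [n](int value) { return value + n; };
    return add(x);
}
```

The operator() is const because a lambda not marked `mutable` cannot modify its by-copy captures.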

Interestingly, lambdas can be recursive.  For example, here’s a lambda that computes the Fibonacci sequence recursively:

#include <iostream>
#include <functional>

using namespace std;
using namespace tr1;

int main()
{
      // implement fib using tr1::function
      function<int(int)> fib1 = [&fib1](int n) -> int
      {
            if(n <= 2)
                  return 1;
            else
                  return fib1(n-1) + fib1(n-2);
      };

      cout<<fib1(6);
}

Note that we had to actually name the lambda in order to implement the recursion.  Without a name, how would you make the recursive call?  This self-reference places some restrictions on how a recursive lambda can be used: the lambda doesn’t refer to itself, it refers to fib1 (captured by reference), which merely happens to hold it.  The restriction is subtle, and is shown by the following code:

      function<int(int)> fib1_copy = fib1; // copy fib1
      fib1 = [](int n) { return -1; };     // fib1 now does something else
      cout<<fib1_copy(6);                  // uh oh, doesn't do what we expect

 

Quiz: what would fib1_copy(6) return?
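One way to remove the dependence on a name is to hand the lambda to itself as an explicit parameter, so every copy carries everything it needs in order to recurse. This is a minimal sketch that assumes a C++14 compiler for the `auto` parameter. The names `fib_step` and `call_with_self` are hypothetical.

```cpp
#include <cassert>
#include <functional>

// The recursive step receives the function to call as its first argument,
// so it never refers to a named variable.
auto fib_step = [](auto const& self, int n) -> int {
    return n <= 2 ? 1 : self(self, n - 1) + self(self, n - 2);
};

// Hypothetical helper: turns a self-passing lambda into an ordinary int(int) function.
template <class F>
std::function<int(int)> call_with_self(F f) {
    // f is captured by value, so the result owns everything the recursion needs
    return [f](int n) { return f(f, n); };
}
```

Because the wrapper stores its own copy of the step instead of referring to a variable, a copy of it keeps working no matter what is later assigned to the original.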

 

Anyway, fib1 isn’t a particularly efficient implementation of the algorithm, since it recomputes the same values over and over.  Still, even though I don’t come from a functional programming background, I find the recursive solution has a certain elegance.  By changing the function to cache values it has already computed, we can make it faster:

function<int(int)> fib2 = [&fib2](int n) -> int
{
      static map<int,int> cache;
      if(n <= 2)
            return 1;
      else if(cache.find(n) == cache.end())
      {
            cache[n] = fib2(n-1) + fib2(n-2);
      }
      return cache[n];
};

(At this point, I should make the caveat that we’re really getting into parlor trick territory.  A functor class is a far better bet for production code, and can do everything here, and more.)
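For comparison, here is a sketch of the functor-class route in C++11. The class name `fib_functor` and the `cached_count()` accessor are hypothetical. The cache lives in the object instead of in a function-local static, so separate instances have separate caches, and the recursion simply goes through `*this`.

```cpp
#include <cassert>
#include <cstddef>
#include <map>

// A hand-written function object that memoizes its own recursive calls.
class fib_functor {
public:
    int operator()(int n) {
        if (n <= 2)
            return 1;
        auto it = cache_.find(n);
        if (it != cache_.end())
            return it->second;                  // already computed
        int value = (*this)(n - 1) + (*this)(n - 2);
        cache_.emplace(n, value);               // remember it for next time
        return value;
    }

    std::size_t cached_count() const { return cache_.size(); }

private:
    std::map<int, int> cache_;                  // per-instance cache
};
```

Copying a `fib_functor` copies its cache. STL's comment below instead shares one cache through `shared_ptr`, which matters when functors are passed around by value.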

Now fib2 is much harder to read, and at that point I might as well implement it iteratively.  This is where the automatic memoization function comes in.  What we’d really like is to write something clean like fib1, and then memoize it to get the speed of fib2:

      // memoize the fib1 function and find fib(6)
      memoize(fib1)(6);

 

The three of us tried for an hour or so to write the memoize() function, but intercepting fib1’s recursion to insert a cache proved difficult.

 

      // what goes in place of these two recursive calls?
      // we want to consult a cache, not call fib1 directly.
      return fib1(n-1) + fib1(n-2);
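The difficulty is that wrapping fib1 from the outside only caches the calls made through the wrapper. The recursive calls inside fib1 still go straight to fib1. Below is a sketch of that naive attempt in C++11. `naive_memoize` is a hypothetical name, and the global `calls` counter exists only for the demonstration.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>

int calls = 0;  // counts how many times fib1's body runs

std::function<int(int)> fib1 = [](int n) -> int {
    ++calls;
    return n <= 2 ? 1 : fib1(n - 1) + fib1(n - 2);  // recursion bypasses any wrapper
};

// Hypothetical naive memoize: caches only what is called through the returned wrapper.
std::function<int(int)> naive_memoize(std::function<int(int)> f) {
    auto cache = std::make_shared<std::map<int, int>>();
    return [f, cache](int n) -> int {
        auto it = cache->find(n);
        if (it != cache->end())
            return it->second;
        int value = f(n);                   // fib1 recurses on its own, uncached
        cache->emplace(n, value);
        return value;
    };
}
```

With this wrapper, computing 6 still runs the body 15 times. A repeated request for 6 is free, but a later request for 7 runs the body 25 times, because none of the intermediate results ever reach the cache.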

 

If we had a function to manage the cache independently, we could make the recursive call this way:

      return check_cache(fib1,n-1) + check_cache(fib1,n-2);
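check_cache is never spelled out here, so the following is one possible C++11 sketch. The single static cache and the `body_calls` counter are assumptions made for the demonstration.

```cpp
#include <cassert>
#include <functional>
#include <map>

// Hypothetical helper: look n up in a cache; call f only on a miss.
int check_cache(std::function<int(int)> const& f, int n) {
    static std::map<int, int> cache;   // one cache for every caller (fine for a single function)
    auto it = cache.find(n);
    if (it != cache.end())
        return it->second;
    int value = f(n);
    cache.emplace(n, value);
    return value;
}

int body_calls = 0;  // counts how many times fib1's body runs

std::function<int(int)> fib1 = [](int n) -> int {
    ++body_calls;
    if (n <= 2)
        return 1;
    return check_cache(fib1, n - 1) + check_cache(fib1, n - 2);
};
```

With the cache in the loop, fib1(6) runs its body 6 times instead of 15, and a second fib1(6) runs it once. A single static cache is only adequate for one function with one argument type. A general helper would also have to key the cache on the function.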

 

This is much cleaner than fib2, but even so, I had to intentionally write fib1 to make use of the cache.  It would be nicer if I didn’t have to do that.  Eventually, we figured out that we needed a way to hook the recursion and insert our own adapter that checks the cache before making the recursive call.  That way, we can write fib1 normally and check the cache under the covers.  tr1::function doesn’t seem to support this, so after a couple of tries, I coded up adaptable_function and memoize_adapter (Warning: Templates Ahead).

 

#include <functional>
#include <iostream>
#include <map>

using namespace std;

template<class Arg>
struct adaptable_function
{
      tr1::function<Arg(Arg)> func;

      typedef tr1::function<Arg(tr1::function<Arg(Arg)>,Arg)> adapter_type;
      adapter_type adapter;

      // binds a function
      adaptable_function(tr1::function<Arg(Arg)> const& f) : func(f)
      {
      }

      // invokes the bound function through an adapter, if one exists
      Arg operator()(Arg p) const
      {
            if(adapter)
                  return adapter(func, p);
            else
                  return func(p);
      }

      void set_adapter(adapter_type const& a)
      {
            adapter = a;
      }

      void clear_adapter()
      {
            adapter = adapter_type(); // better way to clear a tr1::function?
      }

private:
      // relies on self-referential recursion, so the class is non-copyable
      adaptable_function(adaptable_function const&);
};

template<class Arg>
struct memoize_adapter
{
      map<Arg,Arg> cache;

      // receives the raw bound function (the adapter_type signature); its
      // recursive calls re-enter through adaptable_function, and so through this cache
      Arg operator()(tr1::function<Arg(Arg)> const& func, Arg arg)
      {
            if(cache.find(arg) == cache.end())
            {
                  cache[arg] = func(arg);
            }
            return cache[arg];
      }
};

template<class Arg>
adaptable_function<Arg>& memoize(adaptable_function<Arg>& f)
{
      f.set_adapter(memoize_adapter<Arg>());
      return f;
}
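On the question in clear_adapter(): with the C++11 std::function, assigning `nullptr` resets the wrapper to empty. After that it tests false, and calling it throws `std::bad_function_call`. The small sketch below uses std::function instead of tr1::function. The `adapter_type` alias mirrors the one above, and `clear_adapter` and `call_through` are hypothetical free-function versions of the members.

```cpp
#include <cassert>
#include <functional>

using adapter_type = std::function<int(std::function<int(int)>, int)>;

// Resets the adapter to the empty state.
void clear_adapter(adapter_type& adapter) {
    adapter = nullptr;
}

// Mirrors adaptable_function::operator(): use the adapter only if one is set.
int call_through(const adapter_type& adapter, const std::function<int(int)>& func, int p) {
    return adapter ? adapter(func, p) : func(p);
}
```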

 

Now, I can implement fib:

• Using the elegant recursive algorithm
• Using a cache without having to build a cache into the function

Here’s the calling code:

int main()
{
      // direct-initialization: adaptable_function is non-copyable, and the
      // lambda must first be converted to tr1::function
      adaptable_function<int> fib([&] (int n) -> int {
            cout<<"fib("<<n<<")"<<endl; // to show calls
            if(n <= 2)
                  return 1;
            else
                  return fib(n-1) + fib(n-2);
      });

      cout<<"normal result = "<<fib(6)<<endl<<endl;

      fib.set_adapter(memoize_adapter<int>());
      cout<<"memoized result = "<<fib(6)<<endl<<endl;

      fib.clear_adapter();
      cout<<"normal result = "<<fib(6)<<endl<<endl;

      cout<<"memoized result = "<<memoize(fib)(6)<<endl<<endl;
}

 

This does the trick!

There are several things I don’t really like about this code.  First and foremost, memoizing fib changes its state: memoize(fib) doesn’t just return a memoized version of fib, it changes where fib’s recursive calls go.  I think this is a limitation of recursive lambdas, since they have to be named.  Or do they…?

Can anyone make this more elegant in C++?  (Note that Bart De Smet recently wrote on this topic from a C# perspective.  It’s a great post, and worth the read!)
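One possible direction, sketched under the assumption of a C++14 compiler: let the lambda take the function to recurse through as a parameter ("open recursion"). memoize() can then build a new object around it instead of modifying anything. The names `fib_step`, `memoized`, and `plain` are hypothetical. Several of the comments below take a similar approach with named functors.

```cpp
#include <cassert>
#include <map>
#include <memory>

// The recursive step: "recurse" is whatever the caller wants recursion to go through.
auto fib_step = [](auto const& recurse, int n) -> int {
    return n <= 2 ? 1 : recurse(n - 1) + recurse(n - 2);
};

// Hypothetical memoizing wrapper: owns a cache and routes recursion back through itself.
template <class Step>
class memoized {
public:
    explicit memoized(Step step)
        : step_(step), cache_(std::make_shared<std::map<int, int>>()) {}

    int operator()(int n) const {
        auto it = cache_->find(n);
        if (it != cache_->end())
            return it->second;
        int value = step_(*this, n);            // recursive calls come back here
        cache_->emplace(n, value);
        return value;
    }

private:
    Step step_;
    std::shared_ptr<std::map<int, int>> cache_; // shared by copies of this object
};

template <class Step>
memoized<Step> memoize(Step step) { return memoized<Step>(step); }

// Hypothetical plain (uncached) recursion over the same step.
template <class Step>
struct plain {
    Step step;
    int operator()(int n) const { return step(*this, n); }
};
```

Here memoize(fib_step) returns a fresh object with its own cache, and plain recursion over the same fib_step is still available, untouched.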

- Arjun


  • [Arjun]

    > adapter = adapter_type(); // better way to clear a tr1::function?

    Yes: adapter = 0; or adapter = NULL;

    This is terser and slightly more efficient.

  • [Arjun]

    > (At this point, I should make the caveat that we’re really getting into parlor trick territory.

    > A functor class is a far better bet for production code, and can do everything here, and more.)

    I can demonstrate how to use named functors here.

    Memoizing an individual functor is easy:

    C:\Temp>type meow.cpp
    #include <iostream>
    #include <map>
    #include <memory>
    #include <ostream>
    using namespace std;
    using namespace std::tr1;

    class fib {
    private:
       typedef map<int, int> map_t;
       typedef map_t::const_iterator map_ci_t;
       shared_ptr<map_t> m_map;
    public:
       fib() : m_map(new map_t) { }
       int operator()(const int n) {
           const map_ci_t i = m_map->find(n);
           if (i == m_map->end()) {
               cout << "fib()(" << n << ")" << endl;
               const int ret = n < 2 ? n : (*this)(n - 1) + (*this)(n - 2);
               (*m_map)[n] = ret;
               return ret;
           } else {
               return i->second;
           }
       }
    };

    int main() {
       cout << "Result: " << fib()(6) << endl;
    }

    C:\Temp>cl /EHsc /nologo /W4 meow.cpp
    meow.cpp

    C:\Temp>meow
    fib()(6)
    fib()(5)
    fib()(4)
    fib()(3)
    fib()(2)
    fib()(1)
    fib()(0)
    Result: 8

    Note that functors should be efficiently copyable, because they're usually passed by value, which is why fib stores shared_ptr<map_t>.

  • If you're memoizing lots of functors, you may want to lift out the memoization into a separate class. Here's one way to do it:

    C:\Temp>type purr.cpp
    #include <exception>
    #include <iostream>
    #include <map>
    #include <memory>
    #include <ostream>
    #include <string>
    #include <utility>
    using namespace std;
    using namespace std::tr1;

    template <typename StatelessFunctor, typename Result, typename Arg1> class memoizer {
    private:
       typedef map<Arg1, Result> map_t;
       typedef typename map_t::const_iterator map_ci_t;
       shared_ptr<map_t> m_map;
    public:
       memoizer() : m_map(new map_t) { }
       Result operator()(const Arg1& arg1) {
           const map_ci_t i = m_map->find(arg1);
           if (i == m_map->end()) {
               const Result& ret = StatelessFunctor()(*this, arg1);
               m_map->insert(make_pair(arg1, ret));
               return ret;
           } else {
               return i->second;
           }
       }
    };

    struct fib {
       int operator()(memoizer<fib, int, int>& mem, const int n) const {
           cout << "fib()(" << n << ")" << endl;
           return n < 2 ? n : mem(n - 1) + mem(n - 2);
       }
    };

    struct squarefree {
       string operator()(memoizer<squarefree, string, int>& mem, const int n) const {
           cout << "squarefree()(" << n << ")" << endl;
           if (n == 0) {
               return "0";
           } else {
               const string prev = mem(n - 1);
               string ret;
               for (string::const_iterator i = prev.begin(); i != prev.end(); ++i) {
                   switch (*i) {
                       case '0':
                           ret += "12";
                           break;
                       case '1':
                           ret += "102";
                           break;
                       case '2':
                           ret += "0";
                           break;
                       default:
                           cout << "EPIC FAIL" << endl;
                           terminate();
                   }
               }
               return ret;
           }
       }
    };

    int main() {
       memoizer<fib, int, int> mem_fib;
       cout << "mem_fib(6): " << mem_fib(6) << endl << endl;
       cout << "mem_fib(8): " << mem_fib(8) << endl << endl;
       memoizer<squarefree, string, int> mem_sqf;
       cout << "mem_sqf(3): " << mem_sqf(3) << endl << endl;
       cout << "mem_sqf(5): " << mem_sqf(5) << endl << endl;
    }

  • Output:

    C:\Temp>cl /EHsc /nologo /W4 purr.cpp
    purr.cpp

    C:\Temp>purr
    fib()(6)
    fib()(5)
    fib()(4)
    fib()(3)
    fib()(2)
    fib()(1)
    fib()(0)
    mem_fib(6): 8

    fib()(8)
    fib()(7)
    mem_fib(8): 21

    squarefree()(3)
    squarefree()(2)
    squarefree()(1)
    squarefree()(0)
    mem_sqf(3): 10212012

    squarefree()(5)
    squarefree()(4)
    mem_sqf(5): 10212010201210212012102010212012

  • This way, fib and squarefree can focus on their computations, leaving all of the caching work to memoizer.

    Note that I've generalized this to handle arbitrary result and argument types, but it requires stateless unary functors. Generalizing this further to handle stateful functors of arbitrary arity is an exercise left to the reader (a stateful functor would need to be stored within the memoizer, and the memoizer's map would need to be from a tuple of argument types to the result type).

    Also note that I use insert() instead of op[]() because op[]() default-constructs the value if it's not present, and in general result types don't have to have default constructors.

  • This is the best I could do in C++03. Fun exercise, thanks!

    #include <boost/function.hpp>
    #include <map>
    #include <iostream>

    template<typename Signature>
    struct memoized;

    template<typename R,typename T1>
    struct memoized<R(T1)>
    {
     typedef std::map<T1,R> cache_t;
     typedef boost::function<R(memoized<R(T1)> &,T1)> memo_function_t;

     explicit
     memoized(memo_function_t const & f):
       m_f(f)
     {}

     R operator()(T1 t1)
     {
       typename cache_t::iterator it = m_cache.find(t1);
       if(it == m_cache.end())
       {
         std::cout << t1 << " not in cache." << std::endl;
         return m_cache[t1] = m_f(*this,t1);
       }
       else
       {
         std::cout << t1 << " in cache." << std::endl;
         return it->second;
       }
     }

    private:
     memo_function_t m_f;
     cache_t         m_cache;
    };

    template<typename R, typename T1>
    memoized<R(T1)>
    memoize(R(*f)(memoized<R(T1)> &,T1))
    {
     return memoized<R(T1)>(f);
    }

    int main()
    {
     struct myfib
     {
       static std::size_t calc(memoized<std::size_t(std::size_t)> & memo_fib,
                               std::size_t n)
       {
         if(n <= 2) return 1;
         else return memo_fib(n-1) + memo_fib(n-2);
       }
     };

     boost::function<std::size_t(std::size_t)> fib =
       memoize(myfib::calc);

     std::cout << fib(10) << std::endl;
     std::cout << fib(10) << std::endl;
    }

  • Is it not possible to define a Y-combinator for C++ lambdas?  The Y-combinator is how you get recursive anonymous functions in many other languages.

    http://en.wikipedia.org/wiki/Fixed_point_combinator

  • Shouldn't

    fib1 = [](int n) { return -1; };

    be

    fib1 = [](int n) -> int { return -1; };

    ?

    Why does the first one even compile? Is the return value by default int?

  • Also why does

    fib1 = [](int n) -> double { return -1; };

    compile? So function<> isn't type safe?

  • Interesting acrobatics, but I am a KISS fan.

    I prefer not to mandate a C++ black belt (with several dans, on occasion) on coworkers who try to understand and modify my code, so thanks, but I'll pass.

    Is there anything in the above code that cannot be done in plain C in a way that 90% of the dev population can understand and 80% can modify/extend without a mistake?

    Why do architects feel so compelled to save the world by providing infrastructure and plumbing for everything conceivable under the sun?

    What about memoization? If I am in such a corner case where caching the results of a function call will *actually* improve performance, what makes you think I would opt for an obscure and totally incomprehensible generic template that I cannot understand or debug, rather than a custom-tailored, totally non-reusable, top-performing, totally understandable and debuggable solution?

    Don't get me wrong, I am not an anti-STL, do-it-yourself (CMyHashTable, CMyDynamicArray, CMyOS) gung-ho type. I am just a KISS fan (including the rock band). If something can be done in a way that is simpler and easier to understand, debug and extend, then I prefer the simpler way.

    I just get so frustrated when people do all this acrobatic stuff in production code just because (a) they can and (b) it's cool, without stepping back a little or actually having mastered the 'tools' they are using.

    A similar example is 'patternitis'. I have seen countless C++ freshmen read the Gang of Four Design Patterns book and then make a total mess of everything, like deciding to implement the Visitor pattern on a problem that required Composite and ending up coding a different pattern from the same book altogether, while still naming the classes CVisitorXYZ (they probably opened the book on the wrong page at some point).

    I have met exactly one guy (I called him the "Professor") who knew C++ well enough and had the knowledge to apply the patterns where they ought to be applied. His code was a masterpiece and worked like a breeze, but when he left, no one else in the house could figure things out.

    So what's the point of this lambda stuff, really? To increase the expressiveness of the language? Are we writing poetry or software? Why should we turn simple code that everyone understands into ever more elegant and concise code that only a few can understand and make work?

    I have been coding in C (drivers) and C++ for 15 years and not once was I trapped because I was missing lambda expressions or similar syntactic gizmos.

    So what's the point really? Please enlighten me. I don't say that *I* am right and *YOU* are wrong. I am saying that I don't see, I don't understand the positive value that these things bring in that far outweighs the problems they cause by complicating the language.

    I am all ears :-)

    Warm Regards,

    Dimitris Staikos

  • I'm with Dimitris Staikos / bruteforce on this.

    Fix the core of the IDE and compiler first. Then if you have time add these features. I don't hear people screaming for them. On the other hand, I do hear people screaming about poor IDE performance, documentation issues and poor code analysis/warnings.

    Mike

  • Well, first off there's no need for this std::function<> malarky --- just use auto:

    auto fib = [&](int n) -> int{

       return (n<=2)?1:(fib(n-1)+fib(n-2));

    };

    Secondly, with auto and decltype we can write a memoize adapter for any callable function by using decltype to get the return type of the function call. [Warning: untested code follows]

    template<typename T>

    struct memoized

    {

       T func;

       explicit memoized(T func_):

           func(func_)

       {}

       struct adaptable

       {

           struct base

           {

               virtual ~base()

               {}

               virtual const std::type_info& type() const=0;

               virtual bool is_less(base const& rhs) const=0;

               virtual bool is_equal(base const& rhs) const=0;

           };

           template<typename U>

           struct value:

               base

           {

               U data;

               value(U data_):

                   data(data_)

               {}

               virtual const std::type_info& type() const

               {

                   return typeid(U);

               }

               virtual bool is_less(base const& rhs) const

               {

                   return type().before(rhs.type()) ||

                       ((type()==rhs.type()) &&

                        (data<static_cast<value const&>(rhs).data));

               }

               virtual bool is_equal(base const& rhs) const

               {

                   return (type()==rhs.type()) &&

                       (data==static_cast<value const&>(rhs).data);

               }

            };

           std::auto_ptr<base> data;

           template<typename U>

           explicit adaptable(U data_):

               data(new value<U>(data_))

           {}

           template<typename U>

           U& extract() const

           {

               if(typeid(U)==data->type())

               {

                   return static_cast<value<U>*>(data.get())->data;

               }

               else

               {

                   throw "wrong type";

               }

           }

           bool operator<(adaptable const& rhs) const

           {

               return data->is_less(rhs.data);

           }

           bool operator==(adaptable const& rhs) const

           {

               return data->is_equal(rhs.data);

           }

       };

       typedef std::map<adaptable,adaptable> map_type;

       std::shared_ptr<map_type> memo;

       memoized(T func_):

           func(func_),

           memo(new map_type)

       {}

       template<typename U>

       auto operator()(U arg) -> decltype(func(arg))

       {

           typedef decltype(func(arg)) result_type;

           adaptable adapted_arg(arg);

           map_type::iterator existing_entry=memo->find(adapted_arg);

           if(existing_entry==memo->end())

           {

               result_type res=func(arg);

               memo->insert(map_type::value_type(adapted_arg,adaptable(res)));

               return res;

           }

           else

           {

               return existing_entry->second.cast<result_type>();

           }

       }

    };

    template<typename T>

    memoized<T> memoize(T func)

    {

       return memoized<T>(func);

    }

    int main()

    {

       auto f1 = [&f1] (int n) -> int {

           std::cout<<"f1("<<n<<")"<<std::endl;

           if(n<=2)

           {

               return 1;

           }

           else

           {

               return f1(n-1) + f1(n-2);

           }

       };

       f1(10);

       auto f2 = memoize([&f2] (int n) -> int {

               std::cout<<"f2("<<n<<")"<<std::endl;

               if(n<=2)

               {

                   return 1;

               }

               else

               {

                   return f2(n-1) + f2(n-2);

               }

           });

       f2(10);

    }

  • @Andre:

    In

    fib1 = [](int n) { return -1; };

    the body of the lambda is a single return statement, so the return type is the type of the expression in the return statement.

  • Lambdas took the place previously reserved for generic programming :)

    What's Andrei Alexandrescu doing these days?
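On the Y-combinator question above: the textbook untyped Y combinator relies on self-application, which doesn't map directly onto C++'s static types. A fixed-point helper in the same spirit is still short to write with C++14 generic lambdas. It also sidesteps the rule that a variable declared with `auto` cannot be named in its own initializer, which means a lambda stored in an `auto` variable cannot call itself through that variable. The names `y_combinator` and `make_y_combinator` in this sketch are hypothetical and are not a standard facility.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Wraps a lambda that takes "itself" as its first parameter, so the lambda
// never has to be named in order to recurse.
template <class F>
struct y_combinator {
    F f;
    template <class... Args>
    decltype(auto) operator()(Args&&... args) const {
        return f(*this, std::forward<Args>(args)...);  // pass the wrapper in as "self"
    }
};

template <class F>
y_combinator<std::decay_t<F>> make_y_combinator(F&& f) {
    return {std::forward<F>(f)};
}
```

Usage: `auto fib = make_y_combinator([](auto const& self, int n) -> int { return n <= 2 ? 1 : self(n - 1) + self(n - 2); });` after which `fib(10)` is 55. The lambda's explicit `-> int` matters here. Without it, its return type would depend on deducing the wrapper's return type, which in turn depends on the lambda's.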
