DZone Snippets is a public source code repository. Easily build up your personal collection of code snippets, categorize them with tags/keywords, and share them with the world.

Snippets has posted 5883 posts at DZone. View Full User Profile

Weighted Random In C# (csharp)

04.26.2007
| 1108 views |
  • submit to reddit
        Weighted random function in C# 3.0. For C# 2.0, just remove 'this' before the 'Random rnd' parameter of WeightedRandom (it then becomes an ordinary static method instead of an extension method).
    public static partial class Utils
    {
        const string WeightCountMustBeGreaterThanZero = "Weight count must be greater than zero.";
        const string ElementWeightMustBeGreaterThanOrEqualToZero = "Weight must be greater than or equal to zero";
        const string WeightSumMustBeGreaterThanZero = "Sum of weights must be greater than zero.";

        /// <summary>
        /// Returns a random index into <paramref name="weights"/>, where the chance of
        /// picking index i is proportional to weights[i].
        /// </summary>
        /// <param name="rnd">Random number generator used to draw the index.</param>
        /// <param name="weights">List of non-negative weights; at least one must be positive.</param>
        /// <returns>An index in the range [0, weights.Count).</returns>
        /// <exception cref="ArgumentNullException">
        /// Thrown if <paramref name="rnd"/> or <paramref name="weights"/> is null.
        /// </exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// Thrown if weights is empty, if any weight is less than zero,
        /// or if every weight is zero (nothing could be selected).
        /// </exception>
        /// <exception cref="OverflowException">
        /// Thrown if the sum of the weights exceeds <see cref="int.MaxValue"/>.
        /// </exception>
        /// <remarks>
        /// Chance of returned value to be i is weights[i]/weights.Sum().
        /// A weight may be zero; such an index is never selected.
        /// </remarks>
        /// <example>
        /// <code>
        /// var weights = new List&lt;int&gt;(new int[] { 2, 3, 5, 0 });
        /// int v = new Random().WeightedRandom(weights);
        /// </code>
        /// 20% chance for v==0
        /// 30% chance for v==1
        /// 50% chance for v==2
        /// 0% chance for v==3
        /// </example>
        public static int WeightedRandom(this Random rnd, IList<int> weights)
        {
            if (rnd == null)
            {
                throw new ArgumentNullException("rnd");
            }
            if (weights == null)
            {
                throw new ArgumentNullException("weights");
            }
            if (weights.Count == 0)
            {
                throw new ArgumentOutOfRangeException("weights", WeightCountMustBeGreaterThanZero);
            }

            // Only the total is needed to draw a value, so accumulate it (with overflow
            // checking) instead of building a list of running sums.
            int total = 0;
            for (int i = 0; i < weights.Count; i++)
            {
                if (weights[i] < 0)
                {
                    throw new ArgumentOutOfRangeException("weights", ElementWeightMustBeGreaterThanOrEqualToZero);
                }
                total = checked(total + weights[i]);
            }

            // With a zero total, rnd.Next(0) returns 0 and no index has a positive weight
            // to claim it, so reject the input up front.
            if (total == 0)
            {
                throw new ArgumentOutOfRangeException("weights", WeightSumMustBeGreaterThanZero);
            }

            // Draw r uniformly from [0, total) and walk the weights, subtracting each one,
            // until r falls inside the current index's slot.
            int r = rnd.Next(total);
            for (int i = 0; i < weights.Count; i++)
            {
                if (r < weights[i])
                {
                    return i;
                }
                r -= weights[i];
            }

            // Unreachable: r < total == sum of all weights, so the loop always returns.
            throw new InvalidOperationException("Weighted selection failed to pick an index.");
        }
    }

NUnit tests:

    [TestFixture]
    public class TWeightedRandom
    {
        /// <summary>
        /// Draws many samples for weights {1, 0, 2, 3} and checks that the zero-weight
        /// index is never returned and the remaining indices appear in roughly a 1:2:3 ratio.
        /// </summary>
        [Test]
        public void WeightedRandom()
        {
            var weights = new List<int>(new int[] { 1, 0, 2, 3 });
            // Tally hits per index directly instead of storing every sample and rescanning it.
            int[] counts = new int[weights.Count];
            int n = 1000 * 1000;
            // Fixed seed keeps the statistical assertions deterministic across runs.
            Random rnd = new Random(12345);
            for (int i = 0; i < n; i++)
            {
                counts[rnd.WeightedRandom(weights)]++;
            }
            int a = counts[0];
            int z = counts[1];
            int b = counts[2];
            int c = counts[3];
            Assert.AreEqual(n, a + b + c);
            Assert.AreEqual(0, z);
            Assert.Less(Math.Abs((double)b / a - 2), 0.1);
            Assert.Less(Math.Abs((double)c / a - 3), 0.1);
        }

        /// <summary>
        /// Each weight fits in an int, but their sum (4e9) does not; the checked
        /// summation must raise an arithmetic (overflow) exception.
        /// </summary>
        [Test]
        public void WeightedRandomOverflow()
        {
            int num = 1000 * 1000 * 1000;
            var weights = new List<int>(new int[] { 2 * num, 2 * num });
            try
            {
                new Random().WeightedRandom(weights);
                Assert.Fail("overflow not thrown");
            }
            catch (ArithmeticException)
            {
                // Expected: OverflowException derives from ArithmeticException.
            }
        }
    }