random/
weighted_bag.rs

1mod entry;
2mod weight;
3
4use entry::WeightedBagEntry;
5pub use weight::Weight;
6
7#[cfg_attr(
8    feature = "serde",
9    derive(serde::Serialize, serde::Deserialize),
10    serde(from = "Vec<(T, W)>")
11)]
12
13/// A `WeightedBag` is a collection that holds entries of type `T` with associated weights of type `W`.  
14/// The weights determine the likelihood of selecting each entry when retrieving a random item from the bag.
15///
16/// # Type Parameters
17/// - `T`: The type of the entries stored in the bag.
18/// - `W`: A type that implements the [Weight] trait, representing the weight of each entry.
19///
20/// # Features
21/// This struct can derive `Serialize` and `Deserialize` traits when the `serde` feature is enabled.  
22/// It can also be constructed from a vector of tuples `Vec<(T, W)>` containing entries and their corresponding weights.
23///
24/// # Example
25/// ```
26/// let mut bag: random::WeightedBag<&str, u32> = random::WeightedBag::default();
27/// bag.add_entry("apple", 2);
28/// bag.add_entry("banana", 1);
29/// let random_fruit: Option<&&str> = bag.try_get_random();
30/// ```
31
32pub struct WeightedBag<T, W: Weight> {
33    entries: Vec<WeightedBagEntry<T, W>>,
34    weight: Option<W>,
35}
36
37impl<T, W: Weight> WeightedBag<T, W> {
38    /// Adds an entry with given weight to the bag
39    ///
40    /// Panics if the weight is 0
41    pub fn add_entry(&mut self, t: T, weight: W) {
42        // Doesn't make sense + would break the system
43        assert_ne!(weight, W::zero(), "Weightless entries are not allowed");
44
45        // This is pretty ugly but the other way is to use a signed integer type, which would be dumb (waste of half the memory used)
46        // We could use 0 as base and if 0 { self.w = w -1} but adding a weight of 1 as first entry would make the initialisation loop (+ sentinel values are stoopid to use when you have a rich type system)
47        if let Some(acc_weight) = &mut self.weight {
48            *acc_weight += weight;
49        } else {
50            self.weight = Some(weight - W::one());
51        }
52
53        self.entries.push(WeightedBagEntry {
54            inner: t,
55            weight: self.weight.clone().unwrap(),
56        })
57    }
58
59    // I needed this part to be it's own method for tests, and since it's inlined, i don't see it being any different than not
60    #[inline]
61    pub(crate) fn get(&self, r: W) -> Option<&T> {
62        self.entries.iter().find(|e| e.weight >= r).map(|e| &**e)
63    }
64
65    /// Retrieve a random entry from the bag, chances are based on weight
66    pub fn try_get_random(&self) -> Option<&T> {
67        let Some(acc_weight) = self.weight.clone() else {
68            return None;
69        };
70
71        self.get(super::get_inc(W::zero(), acc_weight).into())
72    }
73
74    /// Short hand for [WeightedBag::try_get_random].unwrap()
75    ///
76    /// # Panics if:
77    ///
78    /// - The bag is empty
79    #[inline]
80    pub fn get_random(&self) -> &T {
81        self.try_get_random().unwrap()
82    }
83}
84
85impl<T, W: Weight> From<Vec<(T, W)>> for WeightedBag<T, W> {
86    fn from(items: Vec<(T, W)>) -> Self {
87        let mut new_bag = Self::default();
88        items
89            .into_iter()
90            .for_each(|(item, weight)| new_bag.add_entry(item, weight));
91        new_bag
92    }
93}
94
95impl<T, W: Weight> Default for WeightedBag<T, W> {
96    fn default() -> Self {
97        Self {
98            entries: Vec::new(),
99            weight: None,
100        }
101    }
102}
103
104impl<T: Clone, W: Weight> Clone for WeightedBag<T, W> {
105    fn clone(&self) -> Self {
106        Self {
107            entries: self.entries.clone(),
108            weight: self.weight.clone(),
109        }
110    }
111}
112
113impl<T: std::fmt::Debug, W: Weight + std::fmt::Debug> std::fmt::Debug for WeightedBag<T, W> {
114    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
115        f.debug_struct("WeightedBag")
116            .field("entries", &self.entries)
117            .field("total_weight", &self.weight)
118            .finish()
119    }
120}
121
#[test]
fn test() {
    // Runs the lookup checks for one weight type `T`.
    // Yields `None` as soon as one of the literals does not fit into `T`.
    fn inner_test<T: num_traits::NumCast + Weight>() -> Option<()> {
        // The trailing comments give the range of draw values that select each entry.
        let greetings = vec![
            ("Hi", T::from(2)?),         // 0..=1
            ("Hellow", T::from(1)?),     //  =2
            ("Bonjour", T::from(4)?),    //  3..=6
            ("Holà", T::from(4)?),       //  7..=10
            ("こんにちは", T::from(3)?), // 11..=13
            ("你好", T::from(10)?),      // 14..=23
            ("Olá", T::from(7)?),        // 24..=30
            ("Hej", T::from(5000)?),     // 31..=5030
        ];
        let bag = super::WeightedBag::<&str, T>::from(greetings);

        // The first range is probed through the trait's own constants.
        assert_eq!(bag.get(T::zero()), Some(&"Hi"));
        assert_eq!(bag.get(T::one()), Some(&"Hi"));

        // (draw value, expected entry): lower and upper bound of every remaining range.
        let boundary_cases = [
            (2, "Hellow"),
            (3, "Bonjour"),
            (6, "Bonjour"),
            (7, "Holà"),
            (10, "Holà"),
            (11, "こんにちは"),
            (13, "こんにちは"),
            (14, "你好"),
            (23, "你好"),
            (24, "Olá"),
            (30, "Olá"),
            (31, "Hej"),
            (5030, "Hej"),
        ];
        for (draw, expected) in boundary_cases {
            assert_eq!(bag.get(T::from(draw)?), Some(&expected));
        }

        // One past the last range selects nothing.
        assert_eq!(bag.get(T::from(5031)?), None::<&&str>);

        Some(())
    }

    // `u8` cannot represent 5000, so the helper bails out while building the bag.
    assert_eq!(inner_test::<u8>(), None::<()>);

    // Every wider unsigned type must run all checks to completion.
    assert_eq!(inner_test::<u16>(), Some(()));
    assert_eq!(inner_test::<u32>(), Some(()));
    assert_eq!(inner_test::<u64>(), Some(()));
    assert_eq!(inner_test::<u128>(), Some(()));
}
169    inner_test::<u64>().unwrap(); // should pass
170    inner_test::<u128>().unwrap(); // should pass
171}