1mod entry;
2mod weight;
3
4use entry::WeightedBagEntry;
5pub use weight::Weight;
6
7#[cfg_attr(
8 feature = "serde",
9 derive(serde::Serialize, serde::Deserialize),
10 serde(from = "Vec<(T, W)>")
11)]
12
13pub struct WeightedBag<T, W: Weight> {
33 entries: Vec<WeightedBagEntry<T, W>>,
34 weight: Option<W>,
35}
36
37impl<T, W: Weight> WeightedBag<T, W> {
38 pub fn add_entry(&mut self, t: T, weight: W) {
42 assert_ne!(weight, W::zero(), "Weightless entries are not allowed");
44
45 if let Some(acc_weight) = &mut self.weight {
48 *acc_weight += weight;
49 } else {
50 self.weight = Some(weight - W::one());
51 }
52
53 self.entries.push(WeightedBagEntry {
54 inner: t,
55 weight: self.weight.clone().unwrap(),
56 })
57 }
58
59 #[inline]
61 pub(crate) fn get(&self, r: W) -> Option<&T> {
62 self.entries.iter().find(|e| e.weight >= r).map(|e| &**e)
63 }
64
65 pub fn try_get_random(&self) -> Option<&T> {
67 let Some(acc_weight) = self.weight.clone() else {
68 return None;
69 };
70
71 self.get(super::get_inc(W::zero(), acc_weight).into())
72 }
73
74 #[inline]
80 pub fn get_random(&self) -> &T {
81 self.try_get_random().unwrap()
82 }
83}
84
85impl<T, W: Weight> From<Vec<(T, W)>> for WeightedBag<T, W> {
86 fn from(items: Vec<(T, W)>) -> Self {
87 let mut new_bag = Self::default();
88 items
89 .into_iter()
90 .for_each(|(item, weight)| new_bag.add_entry(item, weight));
91 new_bag
92 }
93}
94
95impl<T, W: Weight> Default for WeightedBag<T, W> {
96 fn default() -> Self {
97 Self {
98 entries: Vec::new(),
99 weight: None,
100 }
101 }
102}
103
104impl<T: Clone, W: Weight> Clone for WeightedBag<T, W> {
105 fn clone(&self) -> Self {
106 Self {
107 entries: self.entries.clone(),
108 weight: self.weight.clone(),
109 }
110 }
111}
112
113impl<T: std::fmt::Debug, W: Weight + std::fmt::Debug> std::fmt::Debug for WeightedBag<T, W> {
114 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
115 f.debug_struct("WeightedBag")
116 .field("entries", &self.entries)
117 .field("total_weight", &self.weight)
118 .finish()
119 }
120}
121
#[test]
fn test() {
    /// Builds a bag from a fixed word list and checks which word each roll lands on.
    ///
    /// Returns `None` as soon as a weight or roll cannot be represented in `T`
    /// (`NumCast::from` failed), and `Some(())` once every check has passed.
    fn inner_test<T: num_traits::NumCast + Weight>() -> Option<()> {
        let words: [(&str, i32); 8] = [
            ("Hi", 2),
            ("Hellow", 1),
            ("Bonjour", 4),
            ("Holà", 4),
            ("こんにちは", 3),
            ("你好", 10),
            ("Olá", 7),
            ("Hej", 5000),
        ];
        let mut items = Vec::new();
        for (word, weight) in words {
            items.push((word, T::from(weight)?));
        }
        let bag = super::WeightedBag::<&str, T>::from(items);

        // (roll, expected word). Each word owns the rolls up to and including the
        // running total of the weights so far minus one.
        let cases: [(i32, Option<&str>); 16] = [
            (0, Some("Hi")),
            (1, Some("Hi")),
            (2, Some("Hellow")),
            (3, Some("Bonjour")),
            (6, Some("Bonjour")),
            (7, Some("Holà")),
            (10, Some("Holà")),
            (11, Some("こんにちは")),
            (13, Some("こんにちは")),
            (14, Some("你好")),
            (23, Some("你好")),
            (24, Some("Olá")),
            (30, Some("Olá")),
            (31, Some("Hej")),
            (5030, Some("Hej")),
            (5031, None),
        ];
        for (roll, expected) in cases {
            assert_eq!(bag.get(T::from(roll)?), expected.as_ref());
        }

        Some(())
    }

    // 5000 does not fit in a u8, so building the u8 bag bails out with `None`.
    assert_eq!(inner_test::<u8>(), None::<()>);
    inner_test::<u16>().unwrap();
    inner_test::<u32>().unwrap();
    inner_test::<u64>().unwrap();
    inner_test::<u128>().unwrap();
}