1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
//! Stream utilities to help implement
//! [`AbstractCircMgr`](super::AbstractCircMgr).

use futures::stream::{Fuse, FusedStream, Stream, StreamExt};
use futures::task::{Context, Poll};
use pin_project::pin_project;
use std::pin::Pin;

/// Label identifying which of the two inputs of a [`SelectBiased`]
/// produced a given item.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(super) enum Source {
    /// The item was produced by the left input, which is the preferred one.
    Left,
    /// The item was produced by the right input, the secondary one.
    Right,
}

/// The combined stream produced by [`select_biased`].
///
/// Refer to [`select_biased`] for a full description of its behavior.
/// Both inputs are stored wrapped in [`Fuse`], so they may be polled again
/// after they have finished.
#[pin_project]
pub(super) struct SelectBiased<S, T> {
    /// The preferred input stream.
    ///
    /// Whenever both inputs have an item ready, the one from this stream is
    /// yielded.  Once this stream ends, the whole `SelectBiased` ends with
    /// it.
    #[pin]
    left: Fuse<S>,
    /// The secondary input stream, consulted only when `left` has nothing
    /// ready.
    #[pin]
    right: Fuse<T>,
}

/// Merge two [`Stream`]s into a single stream that favors the first one.
///
/// This behaves much like [`futures::stream::select`], except that the two
/// inputs are not treated symmetrically:
///
///  * Every yielded item is tagged with [`Source::Left`] or
///    [`Source::Right`] to record which input it came from.
///  * When both inputs have an item ready, the item from `left` always
///    wins.
///  * The combined stream ends as soon as `left` ends, regardless of
///    whether `right` still has items to give.
///
/// # Future plans
///
/// A more descriptive name may be warranted, particularly if this gets
/// used anywhere else.
///
/// Should this ever become public, it may be worth offering each of the
/// above differences from `select` separately.
pub(super) fn select_biased<S, T>(left: S, right: T) -> SelectBiased<S, T>
where
    S: Stream,
    T: Stream<Item = S::Item>,
{
    // Fuse both inputs so that polling them after they finish is harmless.
    let left = left.fuse();
    let right = right.fuse();
    SelectBiased { left, right }
}

impl<S, T> FusedStream for SelectBiased<S, T>
where
    S: Stream,
    T: Stream<Item = S::Item>,
{
    fn is_terminated(&self) -> bool {
        // We're done if the left stream is done, whether the right stream
        // is done or not.
        self.left.is_terminated()
    }
}

impl<S, T> Stream for SelectBiased<S, T>
where
    S: Stream,
    T: Stream<Item = S::Item>,
{
    type Item = (Source, S::Item);

    /// Poll the left input first.  Fall back to the right input only when
    /// the left one is pending.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let projected = self.project();

        // The left input always gets the first chance.  If it is ready, its
        // answer is final: either an item (tagged `Left`) or end-of-stream,
        // in which case the right input is never consulted.
        if let Poll::Ready(outcome) = projected.left.poll_next(cx) {
            return Poll::Ready(outcome.map(|item| (Source::Left, item)));
        }

        // The left input is pending, so its waker is registered.  An item
        // from the right input can be passed along.  Otherwise (right
        // exhausted or pending) we simply wait for the next wakeup.
        match projected.right.poll_next(cx) {
            Poll::Ready(Some(item)) => Poll::Ready(Some((Source::Right, item))),
            Poll::Ready(None) | Poll::Pending => Poll::Pending,
        }
    }
}

#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use futures_await_test::async_test;

    // Tests where only elements from the left stream should be yielded.
    #[async_test]
    async fn left_only() {
        use futures::stream::iter;
        use Source::Left as L;
        // If there's nothing in the right stream, we just yield the left.
        let left = vec![1_usize, 2, 3];
        let right = vec![];

        let s = select_biased(iter(left), iter(right));
        let result: Vec<_> = s.collect().await;
        assert_eq!(result, vec![(L, 1_usize), (L, 2), (L, 3)]);

        // If the left runs out (which this will), we don't yield anything
        // from the right.
        let left = vec![1_usize, 2, 3];
        let right = vec![4, 5, 6];
        let s = select_biased(iter(left), iter(right));
        let result: Vec<_> = s.collect().await;
        assert_eq!(result, vec![(L, 1_usize), (L, 2), (L, 3)]);

        // The same thing happens if the left stream is completely empty!
        let left = vec![];
        let right = vec![4_usize, 5, 6];
        let s = select_biased(iter(left), iter(right));
        let result: Vec<_> = s.collect().await;
        assert_eq!(result, vec![]);
    }

    // Tests where only elements from the right stream should be yielded.
    #[async_test]
    async fn right_only() {
        use futures::stream::{iter, pending};
        use futures::FutureExt;
        use Source::Right as R;

        // Try a forever-pending stream for the left hand side.
        let left = pending();
        let right = vec![4_usize, 5, 6];
        let mut s = select_biased(left, iter(right));
        assert_eq!(s.next().await, Some((R, 4)));
        assert_eq!(s.next().await, Some((R, 5)));
        assert_eq!(s.next().await, Some((R, 6)));

        // The right stream is now exhausted but the left one is still
        // pending, so the combined stream must keep waiting rather than
        // end.  `now_or_never` polls exactly once, so this cannot hang.
        assert_eq!(s.next().now_or_never(), None);
        assert!(!s.is_terminated());
    }

    // Tests where we can find elements from both streams.
    #[async_test]
    async fn multiplex() {
        use futures::SinkExt;
        use Source::{Left as L, Right as R};

        let (mut snd_l, rcv_l) = futures::channel::mpsc::channel(5);
        let (mut snd_r, rcv_r) = futures::channel::mpsc::channel(5);
        let mut s = select_biased(rcv_l, rcv_r);

        snd_l.send(1_usize).await.unwrap();
        snd_r.send(4_usize).await.unwrap();
        snd_l.send(2_usize).await.unwrap();

        assert_eq!(s.next().await, Some((L, 1)));
        assert_eq!(s.next().await, Some((L, 2)));
        assert_eq!(s.next().await, Some((R, 4)));

        snd_r.send(5_usize).await.unwrap();
        snd_l.send(3_usize).await.unwrap();

        assert!(!s.is_terminated());
        drop(snd_r);

        assert_eq!(s.next().await, Some((L, 3)));
        assert_eq!(s.next().await, Some((R, 5)));

        drop(snd_l);
        assert_eq!(s.next().await, None);

        assert!(s.is_terminated());
    }
}