1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
use std::future::Future;
use std::mem;
use std::pin::Pin;
use std::result;
use std::task::Context;
use std::task::Poll;
use tls_api::async_as_sync::AsyncIoAsSyncIo;
use tls_api::spi::save_context;
use tls_api::AsyncSocket;
pub(crate) enum HandshakeFuture<F, S: Unpin> {
Initial(F, AsyncIoAsSyncIo<S>),
MidHandshake(native_tls::MidHandshakeTlsStream<AsyncIoAsSyncIo<S>>),
Done,
}
impl<F, A> Future for HandshakeFuture<F, A>
where
A: AsyncSocket,
F: FnOnce(
AsyncIoAsSyncIo<A>,
) -> result::Result<
native_tls::TlsStream<AsyncIoAsSyncIo<A>>,
native_tls::HandshakeError<AsyncIoAsSyncIo<A>>,
>,
Self: Unpin,
{
type Output = anyhow::Result<crate::TlsStream<A>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
save_context(cx, || {
let self_mut = self.get_mut();
match mem::replace(self_mut, HandshakeFuture::Done) {
HandshakeFuture::Initial(f, stream) => match f(stream) {
Ok(stream) => {
return Poll::Ready(Ok(crate::TlsStream::new(stream)));
}
Err(native_tls::HandshakeError::WouldBlock(mid)) => {
*self_mut = HandshakeFuture::MidHandshake(mid);
return Poll::Pending;
}
Err(native_tls::HandshakeError::Failure(e)) => {
return Poll::Ready(Err(anyhow::Error::new(e)))
}
},
HandshakeFuture::MidHandshake(stream) => match stream.handshake() {
Ok(stream) => {
return Poll::Ready(Ok(crate::TlsStream::new(stream)));
}
Err(native_tls::HandshakeError::WouldBlock(mid)) => {
*self_mut = HandshakeFuture::MidHandshake(mid);
return Poll::Pending;
}
Err(native_tls::HandshakeError::Failure(e)) => {
return Poll::Ready(Err(anyhow::Error::new(e)))
}
},
HandshakeFuture::Done => panic!("Future must not be polled after ready"),
}
})
}
}