brane_prx/
redirect.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
//  REDIRECT.rs
//    by Lut99
//
//  Created:
//    23 Nov 2022, 11:26:46
//  Last edited:
//    03 Jan 2024, 14:54:55
//  Auto updated?
//    Yes
//
//  Description:
//!   Implements handlers for the proxy paths.
//

use std::future::Future;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::Arc;

use brane_cfg::certs::{load_certstore, load_identity};
use brane_cfg::info::Info as _;
use brane_cfg::node::{NodeConfig, NodeSpecificConfig};
use brane_cfg::proxy::ProxyProtocol;
use log::{debug, error, info};
use never_say_never::Never;
use rustls::client::ClientConfig;
use rustls::{Certificate, ConfigBuilder, PrivateKey, RootCertStore, ServerName};
use socksx::{Socks5Client, Socks6Client};
use specifications::address::Address;
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::TlsConnector;
use tokio_rustls::client::TlsStream;
use url::Url;

pub use crate::errors::RedirectError as Error;
use crate::spec::{Context, NewPathRequestTlsOptions};


/***** AUXILLARY STRUCTS *****/
/// Wrapper around both a Socks6Client and a normal "client" to serve proxied or not.
pub enum RemoteClient {
    /// A normal client without any proxying (which for us is just the address).
    Direct,
    /// A socksx-client that proxies using the SOCKS5-protocol.
    Socks5(Socks5Client, Address),
    /// A socksx-client that proxies using the SOCKS6-protocol.
    Socks6(Socks6Client, Address),
}

impl RemoteClient {
    /// Establish a new connection with the remote host.
    ///
    /// # Arguments
    /// - `address`: The address to connect to.
    ///
    /// # Returns
    /// A TcpStream that represents the connetion.
    async fn connect(&self, address: impl AsRef<str>) -> Result<TcpStream, Error> {
        let address: &str = address.as_ref();

        use RemoteClient::*;
        match self {
            Direct => match TcpStream::connect(address).await {
                Ok(conn) => Ok(conn),
                Err(err) => Err(Error::TcpStreamConnectError { address: address.into(), err }),
            },

            Socks5(client, proxy) => match client.connect(address.to_string()).await {
                Ok((conn, addr)) => {
                    debug!("{:?}", addr);
                    Ok(conn)
                },
                Err(err) => Err(Error::Socks5ConnectError { address: address.into(), proxy: proxy.clone(), err }),
            },
            Socks6(client, proxy) => match client.connect(address.to_string(), None, None).await {
                Ok((conn, addr)) => {
                    debug!("{:?}", addr);
                    Ok(conn)
                },
                Err(err) => Err(Error::Socks6ConnectError { address: address.into(), proxy: proxy.clone(), err }),
            },
        }
    }
}





/***** LIBRARY *****/
/// Initializes a path server with the given port and address.
///
/// # Arguments
/// - `context`: The context that is used for the server.
/// - `socket_addr`: The SocketAddress on which to serve.
/// - `remote_addr`: The address to redirect the traffic to.
/// - `tls`: If given, adds TLS encryption to the remote host with the given address.
///
/// # Returns
/// A Future implementing the path server.
///
/// # Errors
/// This function errors if we failed to bind a TCP server on the given port.
pub async fn path_server_factory(
    context: &Arc<Context>,
    socket_addr: SocketAddr,
    remote_addr: String,
    tls: Option<NewPathRequestTlsOptions>,
) -> Result<impl Future<Output = Never>, Error> {
    // Parse the address to discover the hostname
    let remote_addr: Url = match Url::from_str(&remote_addr) {
        Ok(url) => url,
        Err(err) => {
            return Err(Error::IllegalUrl { raw: remote_addr, err });
        },
    };
    let hostname: &str = match remote_addr.domain() {
        Some(hostname) => hostname,
        None => {
            return Err(Error::NoDomainName { raw: remote_addr.to_string() });
        },
    };

    // Parse the given domain as a hostname first, if required by TLS
    let tls: Option<(ServerName, NewPathRequestTlsOptions)> = if let Some(tls) = tls {
        match ServerName::try_from(hostname) {
            Ok(name) => {
                // Assert it's actually a DNS name, since rustls no like IPs
                if !matches!(name, ServerName::DnsName(_)) {
                    return Err(Error::TlsWithNonHostnameError { kind: hostname.into() });
                }
                Some((name, tls))
            },
            Err(err) => {
                return Err(Error::IllegalServerName { raw: hostname.into(), err });
            },
        }
    } else {
        None
    };

    // Attempt to open the TCP server
    let listener: TcpListener = match TcpListener::bind(socket_addr).await {
        Ok(listener) => listener,
        Err(err) => {
            return Err(Error::ListenerCreateError { address: socket_addr, err });
        },
    };

    // Now match on what to do
    if let Some(proxy_cfg) = &context.proxy.forward {
        // Open the relevant client
        let client: RemoteClient = match proxy_cfg.protocol {
            ProxyProtocol::Socks5 => {
                // Attempt to open the socks 5 client
                match Socks5Client::new(proxy_cfg.address.to_string(), None).await {
                    Ok(client) => RemoteClient::Socks5(client, proxy_cfg.address.clone()),
                    Err(err) => {
                        return Err(Error::Socks5CreateError { address: proxy_cfg.address.clone(), err });
                    },
                }
            },

            ProxyProtocol::Socks6 => {
                // Attempt to open the socks 6 client
                match Socks6Client::new(proxy_cfg.address.to_string(), None).await {
                    Ok(client) => RemoteClient::Socks6(client, proxy_cfg.address.clone()),
                    Err(err) => {
                        return Err(Error::Socks6CreateError { address: proxy_cfg.address.clone(), err });
                    },
                }
            },
        };

        // If that was successfull, return the future
        Ok(path_server(context.node_config_path.clone(), listener, client, socket_addr, remote_addr, tls))
    } else {
        // Otherwise, just pass the address as 'to-be-connected'
        Ok(path_server(context.node_config_path.clone(), listener, RemoteClient::Direct, socket_addr, remote_addr, tls))
    }
}



/// Serves the proxying service on the given path, to the given address.
///
/// # Arguments
/// - `node_config_path`: Path to the `node.yml` file where we pull the certificates location from.
/// - `listener`: The TcpListener with which we serve.
/// - `client`: The proxied SOCKS6-client / unproxied TcpClient to connect to remote hosts with.
/// - `socket_addr`: The SocketAddress on which to serve.
/// - `address`: The address to redirect the traffic to.
/// - `tls`: If given, adds TLS encryption to the remote host with the given address.
///
/// # Returns
/// Never, ideally.
///
/// # Errors
/// This function does not error directly, but instead write errors to stderr (using the `log` crate) and then returns.
pub async fn path_server(
    node_config_path: PathBuf,
    listener: TcpListener,
    client: RemoteClient,
    socket_addr: SocketAddr,
    address: Url,
    tls: Option<(ServerName, NewPathRequestTlsOptions)>,
) -> Never {
    info!("Initiated new path ':{}' to '{}'", socket_addr, address);
    loop {
        // Wait for the next connection
        debug!(":{}->{}: Ready for new connection", socket_addr.port(), address);
        let (mut iconn, client_addr): (TcpStream, SocketAddr) = match listener.accept().await {
            Ok(res) => res,
            Err(err) => {
                error!(":{}->{}: Failed to accept incoming request: {}", socket_addr.port(), address, err);
                continue;
            },
        };
        debug!(":{}->{}: Got new connection from '{}'", socket_addr.port(), address, client_addr);

        // Now we establish a new connection to the remote host
        let addr: String = format!("{}:{}", address.domain().unwrap(), address.port().unwrap());
        debug!("Connecting to '{}'...", addr);
        let mut oconn: TcpStream = match client.connect(&addr).await {
            Ok(oconn) => oconn,
            Err(err) => {
                error!(":{}->{}: Failed to connect to remote '{}': {}", socket_addr.port(), address, addr, err);
                continue;
            },
        };

        // Now switch on whether we're using TLS or not.
        if let Some((domain, tls)) = &tls {
            // Use TLS
            debug!(":{}->{}: Setting up TLS for location '{}'...", socket_addr.port(), address, tls.location);

            // Load the node config file
            let node_config: NodeConfig = match NodeConfig::from_path(&node_config_path) {
                Ok(config) => config,
                Err(err) => {
                    error!(":{}->{}: Failed to load NodeConfig file: {}", socket_addr.port(), address, err);
                    std::process::exit(1);
                },
            };

            // Load the certificate path
            let cert_path: &Path = match &node_config.node {
                NodeSpecificConfig::Central(node) => &node.paths.certs,
                NodeSpecificConfig::Worker(node) => &node.paths.certs,
                NodeSpecificConfig::Proxy(node) => &node.paths.certs,
            };

            // Load the root CA certificate file
            let ca_path: PathBuf = cert_path.join(&tls.location).join("ca.pem");
            let ca: RootCertStore = match load_certstore(&ca_path) {
                Ok(store) => store,
                Err(err) => {
                    error!(
                        ":{}->{}: Failed to load root certificate '{}' for location '{}': {}",
                        socket_addr.port(),
                        address,
                        ca_path.display(),
                        tls.location,
                        err
                    );
                    continue;
                },
            };

            // If any, also load the client file
            let client: Option<(PathBuf, Vec<Certificate>, PrivateKey)> = if tls.use_client_auth {
                debug!(":{}->{}: Adding client certificate...", socket_addr.port(), address);
                let client_path: PathBuf = cert_path.join(&tls.location).join("client-id.pem");
                match load_identity(&client_path) {
                    Ok((certs, key)) => Some((client_path, certs, key)),
                    Err(err) => {
                        error!(
                            ":{}->{}: Failed to load client identity file '{}' for location '{}': {}",
                            socket_addr.port(),
                            address,
                            client_path.display(),
                            tls.location,
                            err
                        );
                        continue;
                    },
                }
            } else {
                None
            };

            // Create a ClientConfig with that
            let config: ConfigBuilder<_, _> = ClientConfig::builder().with_safe_defaults().with_root_certificates(ca);
            let config: ClientConfig = if let Some((path, certs, key)) = client {
                match config.with_client_auth_cert(certs, key) {
                    Ok(config) => config,
                    Err(err) => {
                        error!(
                            ":{}->{}: Failed to build client config from '{}' and '{}': {}",
                            socket_addr.port(),
                            address,
                            ca_path.display(),
                            path.display(),
                            err
                        );
                        continue;
                    },
                }
            } else {
                config.with_no_client_auth()
            };

            // We can now wrap the outgoing stream in a TLS client stream.
            debug!(":{}->{}: Negotiating TLS...", socket_addr.port(), address);
            let connector: TlsConnector = TlsConnector::from(Arc::new(config));
            let mut oconn: TlsStream<TcpStream> = match connector.connect(domain.clone(), oconn).await {
                Ok(oconn) => oconn,
                Err(err) => {
                    error!(":{}->{}: Failed to start a TLS connection with '{}': {}", socket_addr.port(), address, addr, err);
                    continue;
                },
            };

            // For the remainder of this session, simply copy the TCP stream both ways
            debug!(":{}->{}: Bidirectional link started", socket_addr.port(), address);
            if let Err(err) = tokio::io::copy_bidirectional(&mut iconn, &mut oconn).await {
                error!(":{}->{}: Bidirectional link failed: {}", socket_addr.port(), address, err);
                continue;
            }
            debug!(":{}->{}: Bidirectional link completed", socket_addr.port(), address);
        } else {
            // Simple TCP

            // For the remainder of this session, simply copy the TCP stream both ways
            debug!(":{}->{}: Bidirectional link started", socket_addr.port(), address);
            if let Err(err) = tokio::io::copy_bidirectional(&mut iconn, &mut oconn).await {
                error!(":{}->{}: Bidirectional link failed: {}", socket_addr.port(), address, err);
                continue;
            }
            debug!(":{}->{}: Bidirectional link completed", socket_addr.port(), address);
        }
    }
}