]> Repositorios git - scryer-prolog.git/commitdiff
implement tls_client_negotiate/3 for explicit negotiation
authorMarkus Triska <[email protected]>
Sun, 5 Dec 2021 16:11:38 +0000 (17:11 +0100)
committerMarkus Triska <[email protected]>
Sun, 5 Dec 2021 16:39:22 +0000 (17:39 +0100)
src/clause_types.rs
src/machine/system_calls.rs

index 5419eb85b58c8cba66dc8b1ba92be68e3fc67b0f..cb65486769be0779b71fc9583e9262f4fb37b54e 100644 (file)
@@ -274,6 +274,7 @@ pub(crate) enum SystemClauseType {
     SocketServerAccept,
     SocketServerClose,
     TLSAcceptClient,
+    TLSClientConnect,
     Succeed,
     TermAttributedVariables,
     TermVariables,
@@ -565,6 +566,7 @@ impl SystemClauseType {
             &SystemClauseType::SocketServerAccept => clause_name!("$socket_server_accept"),
             &SystemClauseType::SocketServerClose => clause_name!("$socket_server_close"),
             &SystemClauseType::TLSAcceptClient => clause_name!("$tls_accept_client"),
+            &SystemClauseType::TLSClientConnect => clause_name!("$tls_client_connect"),
             &SystemClauseType::Succeed => clause_name!("$succeed"),
             &SystemClauseType::TermAttributedVariables => {
                 clause_name!("$term_attributed_variables")
@@ -747,6 +749,7 @@ impl SystemClauseType {
             ("$socket_server_accept", 7) => Some(SystemClauseType::SocketServerAccept),
             ("$socket_server_close", 1) => Some(SystemClauseType::SocketServerClose),
             ("$tls_accept_client", 4) => Some(SystemClauseType::TLSAcceptClient),
+            ("$tls_client_connect", 3) => Some(SystemClauseType::TLSClientConnect),
             ("$store_global_var", 2) => Some(SystemClauseType::StoreGlobalVar),
             ("$store_backtrackable_global_var", 2) => {
                 Some(SystemClauseType::StoreBacktrackableGlobalVar)
index 960cc36196cae2f41dbf5742289bb3d9ca8c3d06..0505e13831c9749f42a006d59898886a8b11e529 100644 (file)
@@ -4176,45 +4176,7 @@ impl MachineState {
                     Ok(tcp_stream) => {
                         let socket_addr = clause_name!(socket_addr, self.atom_tbl);
 
-                        let mut stream = {
-                            let tls = match self.store(self.deref(self[temp_v!(8)])) {
-                                Addr::Con(h) if self.heap.atom_at(h) => {
-                                    if let HeapCellValue::Atom(ref atom, _) = &self.heap[h] {
-                                        atom.as_str()
-                                    } else {
-                                        unreachable!()
-                                    }
-                                }
-                                _ => {
-                                    unreachable!()
-                                }
-                            };
-
-                            match tls {
-                                "false" => Stream::from_tcp_stream(socket_addr, tcp_stream),
-                                "true" => {
-                                    let connector = TlsConnector::new().unwrap();
-                                    let stream = Stream::from_tcp_stream(socket_addr, tcp_stream);
-                                    let stream =
-                                        match connector.connect(socket_atom.as_str(), stream) {
-                                            Ok(tls_stream) => tls_stream,
-                                            Err(_) => {
-                                                return Err(self.open_permission_error(
-                                                    addr,
-                                                    "socket_client_open",
-                                                    3,
-                                                ));
-                                            }
-                                        };
-
-                                    let addr = clause_name!("TLS".to_string(), self.atom_tbl);
-                                    Stream::from_tls_stream(addr, stream)
-                                }
-                                _ => {
-                                    unreachable!()
-                                }
-                            }
-                        };
+                        let mut stream = Stream::from_tcp_stream(socket_addr, tcp_stream);
 
                         *stream.options_mut() = options;
 
@@ -4418,6 +4380,37 @@ impl MachineState {
                     }
                 }
             }
+            &SystemClauseType::TLSClientConnect => {
+                let hostname = self.heap_pstr_iter(self[temp_v!(1)]).to_string();
+
+                let stream0 = self.get_stream_or_alias(
+                    self[temp_v!(2)],
+                    &indices.stream_aliases,
+                    "tls_client_negotiate",
+                    3,
+                )?;
+
+                let connector = TlsConnector::new().unwrap();
+                let stream =
+                    match connector.connect(&hostname, stream0) {
+                        Ok(tls_stream) => tls_stream,
+                        Err(_) => {
+                            return Err(self.open_permission_error(
+                                self[temp_v!(1)],
+                                "tls_client_negotiate",
+                                3,
+                            ));
+                        }
+                    };
+
+                let addr = clause_name!("TLS".to_string(), self.atom_tbl);
+                let stream = Stream::from_tls_stream(addr, stream);
+                indices.streams.insert(stream.clone());
+
+                let stream = self.heap.to_unifiable(HeapCellValue::Stream(stream));
+                let stream_addr = self.store(self.deref(self[temp_v!(3)]));
+                self.bind(stream_addr.as_var().unwrap(), stream);
+            }
             &SystemClauseType::TLSAcceptClient => {
                 let pkcs12 = self.string_encoding_bytes(1, "octet");
                 let password = self.heap_pstr_iter(self[temp_v!(2)]).to_string();