|
| 1 | +//! Bidirectional protocol methods |
| 2 | +
|
| 3 | +use std::{ |
| 4 | + io::{self, BufRead, Write}, |
| 5 | + sync::Arc, |
| 6 | +}; |
| 7 | + |
| 8 | +use base_db::SourceDatabase; |
| 9 | +use paths::AbsPath; |
| 10 | +use span::{FileId, Span}; |
| 11 | + |
| 12 | +use crate::{ |
| 13 | + Codec, ProcMacro, ProcMacroKind, ServerError, |
| 14 | + bidirectional_protocol::msg::{ |
| 15 | + Envelope, ExpandMacro, ExpandMacroData, ExpnGlobals, Kind, Payload, Request, RequestId, |
| 16 | + Response, SubRequest, SubResponse, |
| 17 | + }, |
| 18 | + legacy_protocol::{ |
| 19 | + SpanMode, |
| 20 | + msg::{ |
| 21 | + FlatTree, ServerConfig, SpanDataIndexMap, deserialize_span_data_index_map, |
| 22 | + serialize_span_data_index_map, |
| 23 | + }, |
| 24 | + }, |
| 25 | + process::ProcMacroServerProcess, |
| 26 | + transport::codec::{json::JsonProtocol, postcard::PostcardProtocol}, |
| 27 | + version, |
| 28 | +}; |
| 29 | + |
| 30 | +pub mod msg; |
| 31 | + |
| 32 | +pub trait ClientCallbacks { |
| 33 | + fn handle_sub_request(&mut self, id: u64, req: SubRequest) -> Result<SubResponse, ServerError>; |
| 34 | +} |
| 35 | + |
| 36 | +pub fn run_conversation<C: Codec>( |
| 37 | + writer: &mut dyn Write, |
| 38 | + reader: &mut dyn BufRead, |
| 39 | + buf: &mut C::Buf, |
| 40 | + id: RequestId, |
| 41 | + initial: Payload, |
| 42 | + callbacks: &mut dyn ClientCallbacks, |
| 43 | +) -> Result<Payload, ServerError> { |
| 44 | + let msg = Envelope { id, kind: Kind::Request, payload: initial }; |
| 45 | + let encoded = C::encode(&msg).map_err(wrap_encode)?; |
| 46 | + C::write(writer, &encoded).map_err(wrap_io("failed to write initial request"))?; |
| 47 | + |
| 48 | + loop { |
| 49 | + let maybe_buf = C::read(reader, buf).map_err(wrap_io("failed to read message"))?; |
| 50 | + let Some(b) = maybe_buf else { |
| 51 | + return Err(ServerError { |
| 52 | + message: "proc-macro server closed the stream".into(), |
| 53 | + io: Some(Arc::new(io::Error::new(io::ErrorKind::UnexpectedEof, "closed"))), |
| 54 | + }); |
| 55 | + }; |
| 56 | + |
| 57 | + let msg: Envelope = C::decode(b).map_err(wrap_decode)?; |
| 58 | + |
| 59 | + if msg.id != id { |
| 60 | + return Err(ServerError { |
| 61 | + message: format!("unexpected message id {}, expected {}", msg.id, id), |
| 62 | + io: None, |
| 63 | + }); |
| 64 | + } |
| 65 | + |
| 66 | + match (msg.kind, msg.payload) { |
| 67 | + (Kind::SubRequest, Payload::SubRequest(sr)) => { |
| 68 | + let resp = callbacks.handle_sub_request(id, sr)?; |
| 69 | + let reply = |
| 70 | + Envelope { id, kind: Kind::SubResponse, payload: Payload::SubResponse(resp) }; |
| 71 | + let encoded = C::encode(&reply).map_err(wrap_encode)?; |
| 72 | + C::write(writer, &encoded).map_err(wrap_io("failed to write sub-response"))?; |
| 73 | + } |
| 74 | + (Kind::Response, payload) => { |
| 75 | + return Ok(payload); |
| 76 | + } |
| 77 | + (kind, payload) => { |
| 78 | + return Err(ServerError { |
| 79 | + message: format!( |
| 80 | + "unexpected message kind {:?} with payload {:?}", |
| 81 | + kind, payload |
| 82 | + ), |
| 83 | + io: None, |
| 84 | + }); |
| 85 | + } |
| 86 | + } |
| 87 | + } |
| 88 | +} |
| 89 | + |
| 90 | +fn wrap_io(msg: &'static str) -> impl Fn(io::Error) -> ServerError { |
| 91 | + move |err| ServerError { message: msg.into(), io: Some(Arc::new(err)) } |
| 92 | +} |
| 93 | + |
| 94 | +fn wrap_encode(err: io::Error) -> ServerError { |
| 95 | + ServerError { message: "failed to encode message".into(), io: Some(Arc::new(err)) } |
| 96 | +} |
| 97 | + |
| 98 | +fn wrap_decode(err: io::Error) -> ServerError { |
| 99 | + ServerError { message: "failed to decode message".into(), io: Some(Arc::new(err)) } |
| 100 | +} |
| 101 | + |
| 102 | +pub(crate) fn version_check(srv: &ProcMacroServerProcess) -> Result<u32, ServerError> { |
| 103 | + let request = Payload::Request(Request::ApiVersionCheck {}); |
| 104 | + |
| 105 | + struct NoCallbacks; |
| 106 | + impl ClientCallbacks for NoCallbacks { |
| 107 | + fn handle_sub_request( |
| 108 | + &mut self, |
| 109 | + _id: u64, |
| 110 | + _req: SubRequest, |
| 111 | + ) -> Result<SubResponse, ServerError> { |
| 112 | + Err(ServerError { message: "sub-request not supported here".into(), io: None }) |
| 113 | + } |
| 114 | + } |
| 115 | + |
| 116 | + let mut callbacks = NoCallbacks; |
| 117 | + |
| 118 | + let response_payload = |
| 119 | + run_bidirectional(srv, (0, Kind::Request, request).into(), &mut callbacks)?; |
| 120 | + |
| 121 | + match response_payload { |
| 122 | + Payload::Response(Response::ApiVersionCheck(version)) => Ok(version), |
| 123 | + other => { |
| 124 | + Err(ServerError { message: format!("unexpected response: {:?}", other), io: None }) |
| 125 | + } |
| 126 | + } |
| 127 | +} |
| 128 | + |
| 129 | +/// Enable support for rust-analyzer span mode if the server supports it. |
| 130 | +pub(crate) fn enable_rust_analyzer_spans( |
| 131 | + srv: &ProcMacroServerProcess, |
| 132 | +) -> Result<SpanMode, ServerError> { |
| 133 | + let request = |
| 134 | + Payload::Request(Request::SetConfig(ServerConfig { span_mode: SpanMode::RustAnalyzer })); |
| 135 | + |
| 136 | + struct NoCallbacks; |
| 137 | + impl ClientCallbacks for NoCallbacks { |
| 138 | + fn handle_sub_request( |
| 139 | + &mut self, |
| 140 | + _id: u64, |
| 141 | + _req: SubRequest, |
| 142 | + ) -> Result<SubResponse, ServerError> { |
| 143 | + Err(ServerError { message: "sub-request not supported here".into(), io: None }) |
| 144 | + } |
| 145 | + } |
| 146 | + |
| 147 | + let mut callbacks = NoCallbacks; |
| 148 | + |
| 149 | + let response_payload = |
| 150 | + run_bidirectional(srv, (0, Kind::Request, request).into(), &mut callbacks)?; |
| 151 | + |
| 152 | + match response_payload { |
| 153 | + Payload::Response(Response::SetConfig(ServerConfig { span_mode })) => Ok(span_mode), |
| 154 | + _ => Err(ServerError { message: "unexpected response".to_owned(), io: None }), |
| 155 | + } |
| 156 | +} |
| 157 | + |
| 158 | +/// Finds proc-macros in a given dynamic library. |
| 159 | +pub(crate) fn find_proc_macros( |
| 160 | + srv: &ProcMacroServerProcess, |
| 161 | + dylib_path: &AbsPath, |
| 162 | +) -> Result<Result<Vec<(String, ProcMacroKind)>, String>, ServerError> { |
| 163 | + let request = |
| 164 | + Payload::Request(Request::ListMacros { dylib_path: dylib_path.to_path_buf().into() }); |
| 165 | + |
| 166 | + struct NoCallbacks; |
| 167 | + impl ClientCallbacks for NoCallbacks { |
| 168 | + fn handle_sub_request( |
| 169 | + &mut self, |
| 170 | + _id: u64, |
| 171 | + _req: SubRequest, |
| 172 | + ) -> Result<SubResponse, ServerError> { |
| 173 | + Err(ServerError { message: "sub-request not supported here".into(), io: None }) |
| 174 | + } |
| 175 | + } |
| 176 | + |
| 177 | + let mut callbacks = NoCallbacks; |
| 178 | + |
| 179 | + let response_payload = |
| 180 | + run_bidirectional(srv, (0, Kind::Request, request).into(), &mut callbacks)?; |
| 181 | + |
| 182 | + match response_payload { |
| 183 | + Payload::Response(Response::ListMacros(it)) => Ok(it), |
| 184 | + _ => Err(ServerError { message: "unexpected response".to_owned(), io: None }), |
| 185 | + } |
| 186 | +} |
| 187 | + |
| 188 | +pub(crate) fn expand( |
| 189 | + proc_macro: &ProcMacro, |
| 190 | + db: &dyn SourceDatabase, |
| 191 | + subtree: tt::SubtreeView<'_, Span>, |
| 192 | + attr: Option<tt::SubtreeView<'_, Span>>, |
| 193 | + env: Vec<(String, String)>, |
| 194 | + def_site: Span, |
| 195 | + call_site: Span, |
| 196 | + mixed_site: Span, |
| 197 | + current_dir: String, |
| 198 | +) -> Result<Result<tt::TopSubtree<span::SpanData<span::SyntaxContext>>, String>, crate::ServerError> |
| 199 | +{ |
| 200 | + let version = proc_macro.process.version(); |
| 201 | + let mut span_data_table = SpanDataIndexMap::default(); |
| 202 | + let def_site = span_data_table.insert_full(def_site).0; |
| 203 | + let call_site = span_data_table.insert_full(call_site).0; |
| 204 | + let mixed_site = span_data_table.insert_full(mixed_site).0; |
| 205 | + let task = Payload::Request(Request::ExpandMacro(Box::new(ExpandMacro { |
| 206 | + data: ExpandMacroData { |
| 207 | + macro_body: FlatTree::from_subtree(subtree, version, &mut span_data_table), |
| 208 | + macro_name: proc_macro.name.to_string(), |
| 209 | + attributes: attr |
| 210 | + .map(|subtree| FlatTree::from_subtree(subtree, version, &mut span_data_table)), |
| 211 | + has_global_spans: ExpnGlobals { |
| 212 | + serialize: version >= version::HAS_GLOBAL_SPANS, |
| 213 | + def_site, |
| 214 | + call_site, |
| 215 | + mixed_site, |
| 216 | + }, |
| 217 | + span_data_table: if proc_macro.process.rust_analyzer_spans() { |
| 218 | + serialize_span_data_index_map(&span_data_table) |
| 219 | + } else { |
| 220 | + Vec::new() |
| 221 | + }, |
| 222 | + }, |
| 223 | + lib: proc_macro.dylib_path.to_path_buf().into(), |
| 224 | + env, |
| 225 | + current_dir: Some(current_dir), |
| 226 | + }))); |
| 227 | + |
| 228 | + struct Callbacks<'de> { |
| 229 | + db: &'de dyn SourceDatabase, |
| 230 | + } |
| 231 | + impl<'db> ClientCallbacks for Callbacks<'db> { |
| 232 | + fn handle_sub_request( |
| 233 | + &mut self, |
| 234 | + _id: u64, |
| 235 | + req: SubRequest, |
| 236 | + ) -> Result<SubResponse, ServerError> { |
| 237 | + match req { |
| 238 | + SubRequest::SourceText { file_id, start, end } => { |
| 239 | + let file = FileId::from_raw(file_id); |
| 240 | + let text = self.db.file_text(file).text(self.db); |
| 241 | + |
| 242 | + let slice = text.get(start as usize..end as usize).map(|s| s.to_owned()); |
| 243 | + |
| 244 | + Ok(SubResponse::SourceTextResult { text: slice }) |
| 245 | + } |
| 246 | + } |
| 247 | + } |
| 248 | + } |
| 249 | + |
| 250 | + let mut callbacks = Callbacks { db }; |
| 251 | + |
| 252 | + let response_payload = |
| 253 | + run_bidirectional(&proc_macro.process, (0, Kind::Request, task).into(), &mut callbacks)?; |
| 254 | + |
| 255 | + match response_payload { |
| 256 | + Payload::Response(Response::ExpandMacro(it)) => Ok(it |
| 257 | + .map(|tree| { |
| 258 | + let mut expanded = FlatTree::to_subtree_resolved(tree, version, &span_data_table); |
| 259 | + if proc_macro.needs_fixup_change() { |
| 260 | + proc_macro.change_fixup_to_match_old_server(&mut expanded); |
| 261 | + } |
| 262 | + expanded |
| 263 | + }) |
| 264 | + .map_err(|msg| msg.0)), |
| 265 | + Payload::Response(Response::ExpandMacroExtended(it)) => Ok(it |
| 266 | + .map(|resp| { |
| 267 | + let mut expanded = FlatTree::to_subtree_resolved( |
| 268 | + resp.tree, |
| 269 | + version, |
| 270 | + &deserialize_span_data_index_map(&resp.span_data_table), |
| 271 | + ); |
| 272 | + if proc_macro.needs_fixup_change() { |
| 273 | + proc_macro.change_fixup_to_match_old_server(&mut expanded); |
| 274 | + } |
| 275 | + expanded |
| 276 | + }) |
| 277 | + .map_err(|msg| msg.0)), |
| 278 | + _ => Err(ServerError { message: "unexpected response".to_owned(), io: None }), |
| 279 | + } |
| 280 | +} |
| 281 | + |
| 282 | +fn run_bidirectional( |
| 283 | + srv: &ProcMacroServerProcess, |
| 284 | + msg: Envelope, |
| 285 | + callbacks: &mut dyn ClientCallbacks, |
| 286 | +) -> Result<Payload, ServerError> { |
| 287 | + if let Some(server_error) = srv.exited() { |
| 288 | + return Err(server_error.clone()); |
| 289 | + } |
| 290 | + |
| 291 | + if srv.use_postcard() { |
| 292 | + srv.run_bidirectional::<PostcardProtocol>(msg.id, msg.payload, callbacks) |
| 293 | + } else { |
| 294 | + srv.run_bidirectional::<JsonProtocol>(msg.id, msg.payload, callbacks) |
| 295 | + } |
| 296 | +} |
0 commit comments