use super::RpcDescription;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{quote, quote_spanned};
use std::collections::HashSet;

impl RpcDescription {
	pub(super) fn render_server(&self) -> Result<TokenStream2, syn::Error> {
		let trait_name = quote::format_ident!("{}Server", &self.trait_def.ident);

		let method_impls = self.render_methods()?;
		let into_rpc_impl = self.render_into_rpc()?;

		let async_trait = self.jrps_server_item(quote! { types::__reexports::async_trait });
Igor Aleksanov's avatar
Igor Aleksanov committed

		// Doc-comment to be associated with the server.
		let doc_comment = format!("Server trait implementation for the `{}` RPC API.", &self.trait_def.ident);

		let trait_impl = quote! {
			#[#async_trait]
			#[doc = #doc_comment]
			pub trait #trait_name: Sized + Send + Sync + 'static {
				#method_impls
				#into_rpc_impl
			}
		};

		Ok(trait_impl)
	}

	fn render_methods(&self) -> Result<TokenStream2, syn::Error> {
		let methods = self.methods.iter().map(|method| &method.signature);

		let subscription_sink_ty = self.jrps_server_item(quote! { SubscriptionSink });
		let subscriptions = self.subscriptions.iter().cloned().map(|mut sub| {
			// Add `SubscriptionSink` as the second input parameter to the signature.
			let subscription_sink: syn::FnArg = syn::parse_quote!(subscription_sink: #subscription_sink_ty);
			sub.signature.sig.inputs.insert(1, subscription_sink);
			sub.signature
		});

		Ok(quote! {
			#(#methods)*
			#(#subscriptions)*
		})
	}

	fn render_into_rpc(&self) -> Result<TokenStream2, syn::Error> {
		let jrps_error = self.jrps_server_item(quote! { types::Error });
Igor Aleksanov's avatar
Igor Aleksanov committed
		let rpc_module = self.jrps_server_item(quote! { RpcModule });

Maciej Hirsz's avatar
Maciej Hirsz committed
		let mut registered = HashSet::new();
		let mut errors = Vec::new();
		let mut check_name = |name: String, span: Span| {
			if registered.contains(&name) {
				let message = format!("{:?} is already defined", name);
				errors.push(quote_spanned!(span => compile_error!(#message);));
			} else {
				registered.insert(name);
			}
		};

		let methods = self
			.methods
			.iter()
			.map(|method| {
				// Rust method to invoke (e.g. `self.<foo>(...)`).
				let rust_method_name = &method.signature.sig.ident;
				// Name of the RPC method (e.g. `foo_makeSpam`).
				let rpc_method_name = self.rpc_identifier(&method.name);
				// `parsing` is the code associated with parsing structure from the
				// provided `RpcParams` object.
				// `params_seq` is the comma-delimited sequence of parametsrs.
				let is_method = true;
				let (parsing, params_seq) = self.render_params_decoding(&method.params, is_method);

				check_name(rpc_method_name.clone(), rust_method_name.span());

				if method.signature.sig.asyncness.is_some() {
					quote! {
						rpc.register_async_method(#rpc_method_name, |params, context| {
							let fut = async move {
								#parsing
								Ok(context.as_ref().#rust_method_name(#params_seq).await)
							};
							Box::pin(fut)
						})?;
					}
				} else {
					quote! {
						rpc.register_method(#rpc_method_name, |params, context| {
Igor Aleksanov's avatar
Igor Aleksanov committed
							#parsing
Maciej Hirsz's avatar
Maciej Hirsz committed
							Ok(context.#rust_method_name(#params_seq))
						})?;
					}
Igor Aleksanov's avatar
Igor Aleksanov committed
				}
Maciej Hirsz's avatar
Maciej Hirsz committed
			})
			.collect::<Vec<_>>();

		let subscriptions = self
			.subscriptions
			.iter()
			.map(|sub| {
				// Rust method to invoke (e.g. `self.<foo>(...)`).
				let rust_method_name = &sub.signature.sig.ident;
				// Name of the RPC method to subscribe to (e.g. `foo_sub`).
				let rpc_sub_name = self.rpc_identifier(&sub.name);
				// Name of the RPC method to unsubscribe (e.g. `foo_sub`).
				let rpc_unsub_name = self.rpc_identifier(&sub.unsub_method);
				// `parsing` is the code associated with parsing structure from the
				// provided `RpcParams` object.
				// `params_seq` is the comma-delimited sequence of parametsrs.
				let is_method = false;
				let (parsing, params_seq) = self.render_params_decoding(&sub.params, is_method);

				check_name(rpc_sub_name.clone(), rust_method_name.span());
				check_name(rpc_unsub_name.clone(), rust_method_name.span());

Igor Aleksanov's avatar
Igor Aleksanov committed
				quote! {
Maciej Hirsz's avatar
Maciej Hirsz committed
					rpc.register_subscription(#rpc_sub_name, #rpc_unsub_name, |params, sink, context| {
Igor Aleksanov's avatar
Igor Aleksanov committed
						#parsing
Maciej Hirsz's avatar
Maciej Hirsz committed
						Ok(context.as_ref().#rust_method_name(sink, #params_seq))
Igor Aleksanov's avatar
Igor Aleksanov committed
					})?;
				}
Maciej Hirsz's avatar
Maciej Hirsz committed
			})
			.collect::<Vec<_>>();
Igor Aleksanov's avatar
Igor Aleksanov committed

		let doc_comment = "Collects all the methods and subscriptions defined in the trait \
								and adds them into a single `RpcModule`.";

		Ok(quote! {
			#[doc = #doc_comment]
Maciej Hirsz's avatar
Maciej Hirsz committed
			fn into_rpc(self) -> #rpc_module<Self> {
				let inner = move || -> Result<#rpc_module<Self>, #jrps_error> {
					let mut rpc = #rpc_module::new(self);

					#(#errors)*
					#(#methods)*
					#(#subscriptions)*
Maciej Hirsz's avatar
Maciej Hirsz committed
					Ok(rpc)
				};
Maciej Hirsz's avatar
Maciej Hirsz committed
				inner().expect("RPC macro method names should never conflict")
Igor Aleksanov's avatar
Igor Aleksanov committed
			}
		})
	}

	fn render_params_decoding(
		&self,
		params: &[(syn::PatIdent, syn::Type)],
		is_method: bool,
	) -> (TokenStream2, TokenStream2) {
		if params.is_empty() {
			return (TokenStream2::default(), TokenStream2::default());
		}

		// Implementations for `.map_err(...)?` and `.ok_or(...)?` with respect to the expected
		// error return type.
		let (err, map_err_impl, ok_or_impl) = if is_method {
			// For methods, we return `CallError`.
			let jrps_call_error = self.jrps_server_item(quote! { types::CallError });
Igor Aleksanov's avatar
Igor Aleksanov committed
			let err = quote! { #jrps_call_error::InvalidParams };
			let map_err = quote! { .map_err(|_| #jrps_call_error::InvalidParams)? };
			let ok_or = quote! { .ok_or(#jrps_call_error::InvalidParams)? };
			(err, map_err, ok_or)
		} else {
			// For subscriptions, we return `Error`.
			// Note that while `Error` can be constructed from `CallError`, we should not do it,
			// because it would be an abuse of the error type semantics.
Igor Aleksanov's avatar
Igor Aleksanov committed
			// Instead, we use suitable top-level error variants.
			let jrps_error = self.jrps_server_item(quote! { types::Error });
Igor Aleksanov's avatar
Igor Aleksanov committed
			let err = quote! { #jrps_error::Request("Required paramater missing".into()) };
			let map_err = quote! { .map_err(|err| #jrps_error::ParseError(err))? };
			let ok_or = quote! { .ok_or(#jrps_error::Request("Required paramater missing".into()))? };
			(err, map_err, ok_or)
		};

		let serde_json = self.jrps_server_item(quote! { types::__reexports::serde_json });
Igor Aleksanov's avatar
Igor Aleksanov committed

		// Parameters encoded as a tuple (to be parsed from array).
		let (params_fields_seq, params_types_seq): (Vec<_>, Vec<_>) = params.iter().cloned().unzip();
		let params_types = quote! { (#(#params_types_seq),*) };
		let params_fields = quote! { (#(#params_fields_seq),*) };

		// Code to decode sequence of parameters from a JSON array.
		let decode_array = {
			let decode_fields = params.iter().enumerate().map(|(id, (name, ty))| {
				if is_option(ty) {
					quote! {
						let #name = arr
							.get(#id)
							.cloned()
							.map(#serde_json::from_value)
							.transpose()
							#map_err_impl;
					}
				} else {
					quote! {
						let #name = arr
							.get(#id)
							.cloned()
							.map(#serde_json::from_value)
							#ok_or_impl
							#map_err_impl;
					}
				}
			});

			quote! {
				#(#decode_fields);*
				#params_fields
			}
		};

		// Code to decode sequence of parameters from a JSON object (aka map).
		let decode_map = {
			let decode_fields = params.iter().map(|(name, ty)| {
				let name_str = name.ident.to_string();
				if is_option(ty) {
					quote! {
						let #name = obj
							.get(#name_str)
							.cloned()
							.map(#serde_json::from_value)
							.transpose()
							#map_err_impl;
					}
				} else {
					quote! {
						let #name = obj
							.get(#name_str)
							.cloned()
							.map(#serde_json::from_value)
							#ok_or_impl
							#map_err_impl;
					}
				}
			});

			quote! {
				#(#decode_fields);*
				#params_fields
			}
		};

		// Code to decode single parameter from a JSON primitive.
		let decode_single = if params.len() == 1 {
			quote! {
				#serde_json::from_value(json)
				#map_err_impl
			}
		} else {
			quote! { return Err(#err);}
		};

		// Parsing of `serde_json::Value`.
		let parsing = quote! {
			let json: #serde_json::Value = params.parse()?;
			let #params_fields: #params_types = match json {
				#serde_json::Value::Null => return Err(#err),
				#serde_json::Value::Array(arr) => {
					#decode_array
				}
				#serde_json::Value::Object(obj) => {
					#decode_map
				}
				_ => {
					#decode_single
				}
			};
		};

		let seq = quote! {
			#(#params_fields_seq),*
		};

		(parsing, seq)
	}
}

/// Checks whether the provided type is an `Option<...>`.
///
/// The check is purely syntactic, because a proc macro has no name resolution. A plain path
/// type whose *last* segment is `Option` (e.g. `Option<T>`, `std::option::Option<T>`) is treated
/// as optional. Parenthesized types and invisible groups are looked through; such groups
/// appear, for example, around types forwarded by `macro_rules!` `$ty:ty` fragments.
/// Qualified paths such as `<T as Trait>::Option` are not considered optional.
fn is_option(ty: &syn::Type) -> bool {
	match ty {
		syn::Type::Path(type_path) if type_path.qself.is_none() => {
			type_path.path.segments.last().map_or(false, |segment| segment.ident == "Option")
		}
		syn::Type::Group(group) => is_option(&group.elem),
		syn::Type::Paren(paren) => is_option(&paren.elem),
		_ => false,
	}
}