Skip to content

Commit

Permalink
optional macro parameter for adding Sync marker to returned Future type
Browse files Browse the repository at this point in the history
fixes #77

Sponsored-by: 49nord GmbH
  • Loading branch information
problame committed May 14, 2020
1 parent d2265ce commit 485bd23
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 21 deletions.
30 changes: 23 additions & 7 deletions src/args.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use proc_macro2::Span;
use syn::parse::{Error, Parse, ParseStream, Result};
use syn::spanned::Spanned;
use syn::Token;

#[derive(Copy, Clone)]
pub struct Args {
pub local: bool,
pub sync: bool,
}

mod kw {
Expand All @@ -21,16 +23,30 @@ impl Parse for Args {
}

fn try_parse(input: ParseStream) -> Result<Args> {
if input.peek(Token![?]) {
input.parse::<Token![?]>()?;
input.parse::<kw::Send>()?;
Ok(Args { local: true })
} else {
Ok(Args { local: false })
let mut send = false;
let mut sync = false;

let arg_list: syn::punctuated::Punctuated<syn::TypeParamBound, Token![+]>;
arg_list = input.parse_terminated(syn::TypeParamBound::parse)?;
for bound in arg_list.into_iter() {
let error = || Error::new(bound.span(), r#"only "?Send" and "Sync" are allowed"#);

let (modifier, path) = match &bound {
syn::TypeParamBound::Trait(syn::TraitBound { modifier, path, .. }) => (modifier, path),
_ => return Err(error()),
};
let ident = path.get_ident().ok_or_else(error)?;

match (modifier, ident.to_string().as_ref()) {
(syn::TraitBoundModifier::Maybe(_), "Send") => send = true,
(syn::TraitBoundModifier::None, "Sync") => sync = true,
_ => return Err(error()),
}
}
Ok(Args { local: send, sync })
}

fn error() -> Error {
let msg = "expected #[async_trait] or #[async_trait(?Send)]";
let msg = "expected #[async_trait], #[async_trait(?Send)]], #[async_trait(?Send, Sync)] or #[async_trait(Sync)]";
Error::new(Span::call_site(), msg)
}
33 changes: 20 additions & 13 deletions src/expand.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::args::Args;
use crate::lifetime::{has_async_lifetime, CollectLifetimes};
use crate::parse::Item;
use crate::receiver::{
Expand Down Expand Up @@ -55,7 +56,7 @@ impl Context<'_> {

type Supertraits = Punctuated<TypeParamBound, Token![+]>;

pub fn expand(input: &mut Item, is_local: bool) {
pub fn expand(input: &mut Item, args: Args) {
match input {
Item::Trait(input) => {
let context = Context::Trait {
Expand All @@ -71,10 +72,10 @@ pub fn expand(input: &mut Item, is_local: bool) {
let mut has_self = has_self_in_sig(sig);
if let Some(block) = block {
has_self |= has_self_in_block(block);
transform_block(context, sig, block, has_self, is_local);
transform_block(context, sig, block, has_self, args);
}
let has_default = method.default.is_some();
transform_sig(context, sig, has_self, has_default, is_local);
transform_sig(context, sig, has_self, has_default, args);
method.attrs.push(parse_quote!(#[must_use]));
}
}
Expand All @@ -92,8 +93,8 @@ pub fn expand(input: &mut Item, is_local: bool) {
if sig.asyncness.is_some() {
let block = &mut method.block;
let has_self = has_self_in_sig(sig) || has_self_in_block(block);
transform_block(context, sig, block, has_self, is_local);
transform_sig(context, sig, has_self, false, is_local);
transform_block(context, sig, block, has_self, args);
transform_sig(context, sig, has_self, false, args);
}
}
}
Expand All @@ -119,7 +120,7 @@ fn transform_sig(
sig: &mut Signature,
has_self: bool,
has_default: bool,
is_local: bool,
args: Args,
) {
sig.fn_token.span = sig.asyncness.take().unwrap().span;

Expand Down Expand Up @@ -195,7 +196,7 @@ fn transform_sig(
Context::Trait { supertraits, .. } => !has_default || has_bound(supertraits, &bound),
Context::Impl { .. } => true,
};
where_clause.predicates.push(if assume_bound || is_local {
where_clause.predicates.push(if assume_bound || args.local {
parse_quote!(Self: 'async_trait)
} else {
parse_quote!(Self: ::core::marker::#bound + 'async_trait)
Expand All @@ -220,10 +221,16 @@ fn transform_sig(
}
}

let bounds = if is_local {
quote!('async_trait)
} else {
quote!(::core::marker::Send + 'async_trait)
let bounds = {
let mut bounds = TokenStream::new();
bounds.extend(quote! {'async_trait});
if !args.local {
bounds.extend(quote! { + ::core::marker::Send});
}
if args.sync {
bounds.extend(quote! { + ::core::marker::Sync});
}
bounds
};

sig.output = parse_quote! {
Expand All @@ -248,7 +255,7 @@ fn transform_block(
sig: &mut Signature,
block: &mut Block,
has_self: bool,
is_local: bool,
macro_args: Args,
) {
if let Some(Stmt::Item(syn::Item::Verbatim(item))) = block.stmts.first() {
if block.stmts.len() == 1 && item.to_string() == ";" {
Expand Down Expand Up @@ -395,7 +402,7 @@ fn transform_block(
if has_self {
let (_, generics, _) = generics.split_for_impl();
let mut self_param: TypeParam = parse_quote!(AsyncTrait: ?Sized + #name #generics);
if !is_local {
if !macro_args.local {
self_param.bounds.extend(self_bound);
}
standalone
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,6 @@ use syn::parse_macro_input;
pub fn async_trait(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as Args);
let mut item = parse_macro_input!(input as Item);
expand(&mut item, args.local);
expand(&mut item, args);
TokenStream::from(quote!(#item))
}
16 changes: 16 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,22 @@ pub async fn test_object_no_send() {
object.f().await;
}

pub async fn test_sync_marker() {
#[async_trait(Sync)]
trait Interface {
async fn f(&self);
}

#[async_trait(Sync)]
impl Interface for Struct {
async fn f(&self) {}
}

let object = &Struct as &dyn Interface;
let _future_is_sync: std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Sync>> =
object.f();
}

#[async_trait]
pub unsafe trait UnsafeTrait {}

Expand Down

0 comments on commit 485bd23

Please sign in to comment.