juniper_codegen/graphql_enum/
derive.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
//! Code generation for `#[derive(GraphQLEnum)]` macro.

use std::collections::HashSet;

use proc_macro2::TokenStream;
use quote::ToTokens as _;
use syn::{ext::IdentExt as _, parse_quote, spanned::Spanned};

use crate::common::{diagnostic, rename, scalar, SpanContainer};

use super::{ContainerAttr, Definition, ValueDefinition, VariantAttr};

/// [`diagnostic::Scope`] of errors for `#[derive(GraphQLEnum)]` macro.
///
/// Every error produced in this module is issued through this scope. This
/// covers both returned errors (`custom_error`) and emitted diagnostics
/// (`emit_custom`, `no_double_underscore`).
const ERR: diagnostic::Scope = diagnostic::Scope::EnumDerive;

/// Expands `#[derive(GraphQLEnum)]` macro into generated code.
pub(crate) fn expand(input: TokenStream) -> syn::Result<TokenStream> {
    let ast = syn::parse2::<syn::DeriveInput>(input)?;
    let attr = ContainerAttr::from_attrs("graphql", &ast.attrs)?;

    let data = if let syn::Data::Enum(data) = &ast.data {
        data
    } else {
        return Err(ERR.custom_error(ast.span(), "can only be derived on enums"));
    };

    let mut has_ignored_variants = false;
    let renaming = attr
        .rename_values
        .map(SpanContainer::into_inner)
        .unwrap_or(rename::Policy::ScreamingSnakeCase);
    let values = data
        .variants
        .iter()
        .filter_map(|v| {
            parse_value(v, renaming).or_else(|| {
                has_ignored_variants = true;
                None
            })
        })
        .collect::<Vec<_>>();

    diagnostic::abort_if_dirty();

    if values.is_empty() {
        return Err(ERR.custom_error(
            data.variants.span(),
            "expected at least 1 non-ignored enum variant",
        ));
    }

    let unique_values = values.iter().map(|v| &v.name).collect::<HashSet<_>>();
    if unique_values.len() != values.len() {
        return Err(ERR.custom_error(
            data.variants.span(),
            "expected all GraphQL enum values to have unique names",
        ));
    }

    let name = attr
        .name
        .clone()
        .map(SpanContainer::into_inner)
        .unwrap_or_else(|| ast.ident.unraw().to_string())
        .into_boxed_str();
    if !attr.is_internal && name.starts_with("__") {
        ERR.no_double_underscore(
            attr.name
                .as_ref()
                .map(SpanContainer::span_ident)
                .unwrap_or_else(|| ast.ident.span()),
        );
    }

    let context = attr
        .context
        .map_or_else(|| parse_quote! { () }, SpanContainer::into_inner);

    let scalar = scalar::Type::parse(attr.scalar.as_deref(), &ast.generics);

    diagnostic::abort_if_dirty();

    let definition = Definition {
        ident: ast.ident,
        generics: ast.generics,
        name,
        description: attr.description.map(SpanContainer::into_inner),
        context,
        scalar,
        values,
        has_ignored_variants,
    };

    Ok(definition.into_token_stream())
}

/// Parses a [`ValueDefinition`] from the given Rust enum variant definition.
///
/// Returns [`None`] if the parsing fails, or the enum variant is ignored.
fn parse_value(v: &syn::Variant, renaming: rename::Policy) -> Option<ValueDefinition> {
    let attr = VariantAttr::from_attrs("graphql", &v.attrs)
        .map_err(diagnostic::emit_error)
        .ok()?;

    if attr.ignore.is_some() {
        return None;
    }

    if !v.fields.is_empty() {
        err_variant_with_fields(&v.fields)?;
    }

    let name = attr
        .name
        .map_or_else(
            || renaming.apply(&v.ident.unraw().to_string()),
            SpanContainer::into_inner,
        )
        .into_boxed_str();

    Some(ValueDefinition {
        ident: v.ident.clone(),
        name,
        description: attr.description.map(SpanContainer::into_inner),
        deprecated: attr.deprecated.map(SpanContainer::into_inner),
    })
}

/// Reports a "no fields allowed for non-ignored variants" error located at
/// the span of `span`.
///
/// Always returns [`None`], for any `T`. Callers can therefore bail out right
/// after reporting, either with `?` or by returning the result directly.
pub fn err_variant_with_fields<T, S: Spanned>(span: &S) -> Option<T> {
    let location = span.span();
    ERR.emit_custom(location, "no fields allowed for non-ignored variants");
    None
}