Skip to content

Commit

Permalink
Trial an IncrementalValuesProvider.GroupBy combinator.
Browse files Browse the repository at this point in the history
  • Loading branch information
eiriktsarpalis committed Apr 10, 2024
1 parent b41e087 commit 029bf54
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.Collections.Immutable;
using System.Diagnostics;

namespace TypeShape.SourceGenerator.Helpers;

public readonly struct TypeWithAttributeDeclarationContext
{
public required BaseTypeDeclarationSyntax DeclarationSyntax { get; init; }
public required ITypeSymbol TypeSymbol { get; init; }
public required SemanticModel SemanticModel { get; init; }
public required ImmutableArray<(BaseTypeDeclarationSyntax Syntax, SemanticModel Model)> Declarations { get; init; }
}

internal static partial class RoslynHelpers
Expand All @@ -21,22 +21,33 @@ public static IncrementalValuesProvider<TypeWithAttributeDeclarationContext> For
this SyntaxValueProvider provider, string attributeFullyQualifiedName,
Func<BaseTypeDeclarationSyntax, CancellationToken, bool> predicate)
{
string attributeName = ParseTypeName(SyntaxFactory.ParseName(attributeFullyQualifiedName), out int attributeArity);
NameSyntax attributeNameSyntax = SyntaxFactory.ParseName(attributeFullyQualifiedName);
string attributeName = GetTypeName(attributeNameSyntax, out int attributeArity);
ImmutableArray<string> attributeNamespace = GetNamespaceTokens(attributeNameSyntax);
string? attributeNameMinusSuffix = attributeName.EndsWith("Attribute", StringComparison.Ordinal) ? attributeName[..^"Attribute".Length] : null;
string? attributeNamespace = attributeFullyQualifiedName.LastIndexOf('.') is >= 0 and int i ? attributeFullyQualifiedName[..i] : null;

return provider.CreateSyntaxProvider(
predicate: (SyntaxNode node, CancellationToken token) => node is BaseTypeDeclarationSyntax typeDecl && IsAnnotatedTypeDeclaration(typeDecl, token),
transform: Transform)
.Where(ctx => ctx.DeclarationSyntax != null);
.Where(ctx => ctx.Type != null)
.GroupBy(
keySelector: value => value.Type,
resultSelector: static (key, values) =>
new TypeWithAttributeDeclarationContext
{
TypeSymbol = (ITypeSymbol)key!,
Declarations = values.Select(v => (v.Syntax, v.Model)).ToImmutableArray()
},

keyComparer: SymbolEqualityComparer.Default);

bool IsAnnotatedTypeDeclaration(BaseTypeDeclarationSyntax typeDecl, CancellationToken token)
{
foreach (AttributeListSyntax attributeList in typeDecl.AttributeLists)
{
foreach (AttributeSyntax attribute in attributeList.Attributes)
{
string name = ParseTypeName(attribute.Name, out int arity);
string name = GetTypeName(attribute.Name, out int arity);
if ((name == attributeName || name == attributeNameMinusSuffix) && arity == attributeArity)
{
return predicate(typeDecl, token);
Expand All @@ -47,7 +58,7 @@ bool IsAnnotatedTypeDeclaration(BaseTypeDeclarationSyntax typeDecl, Cancellation
return false;
}

TypeWithAttributeDeclarationContext Transform(GeneratorSyntaxContext ctx, CancellationToken token)
(ITypeSymbol Type, BaseTypeDeclarationSyntax Syntax, SemanticModel Model) Transform(GeneratorSyntaxContext ctx, CancellationToken token)
{
BaseTypeDeclarationSyntax typeDecl = (BaseTypeDeclarationSyntax)ctx.Node;
ITypeSymbol typeSymbol = ctx.SemanticModel.GetDeclaredSymbol(typeDecl, token)!;
Expand All @@ -57,16 +68,16 @@ TypeWithAttributeDeclarationContext Transform(GeneratorSyntaxContext ctx, Cancel
if (attrData.AttributeClass is INamedTypeSymbol attributeType &&
attributeType.Name == attributeName &&
attributeType.Arity == attributeArity &&
attributeType.ContainingNamespace?.ToDisplayString() == attributeNamespace)
attributeType.ContainingNamespace.MatchesNamespace(attributeNamespace))
{
return new() { SemanticModel = ctx.SemanticModel, DeclarationSyntax = typeDecl, TypeSymbol = typeSymbol };
return (typeSymbol, typeDecl, ctx.SemanticModel);
}
}

return default;
}

static string ParseTypeName(NameSyntax nameSyntax, out int genericTypeArity)
static string GetTypeName(NameSyntax nameSyntax, out int genericTypeArity)
{
while (true)
{
Expand All @@ -90,5 +101,45 @@ static string ParseTypeName(NameSyntax nameSyntax, out int genericTypeArity)
}
}
}

static ImmutableArray<string> GetNamespaceTokens(NameSyntax nameSyntax)
{
var tokens = new List<SimpleNameSyntax>();
Traverse(nameSyntax);

SimpleNameSyntax typeName = tokens[^1];
return tokens.Select(t => t.Identifier.Text).Take(tokens.Count - 1).ToImmutableArray();

void Traverse(NameSyntax current)
{
switch (current)
{
case SimpleNameSyntax simpleName:
tokens.Add(simpleName);
break;
case QualifiedNameSyntax qualifiedName:
Traverse(qualifiedName.Left);
Traverse(qualifiedName.Right);
break;
case AliasQualifiedNameSyntax alias:
Traverse(alias.Name);
break;
default:
Debug.Fail("Unrecognized NameSyntax");
break;
}
}
}
}

// Cf. https://github.com/dotnet/roslyn/issues/72667
public static IncrementalValuesProvider<TResult> GroupBy<TSource, TKey, TResult>(
this IncrementalValuesProvider<TSource> source,
Func<TSource, TKey> keySelector,
Func<TKey, IEnumerable<TSource>, TResult> resultSelector,
IEqualityComparer<TKey>? keyComparer = null)
{
keyComparer ??= EqualityComparer<TKey>.Default;
return source.Collect().SelectMany((values, _) => values.GroupBy(keySelector, resultSelector, keyComparer));
}
}
15 changes: 15 additions & 0 deletions src/TypeShape.SourceGenerator/Helpers/RoslynHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,21 @@ public static bool IsGenericTypeDefinition(this ITypeSymbol type)
=> type is INamedTypeSymbol { IsGenericType: true } namedType &&
SymbolEqualityComparer.Default.Equals(namedType.OriginalDefinition, type);

public static bool MatchesNamespace(this ISymbol? symbol, ImmutableArray<string> namespaceTokens)
{
for (int i = namespaceTokens.Length - 1; i >= 0; i--)
{
if (symbol?.Name != namespaceTokens[i])
{
return false;
}

symbol = symbol.ContainingNamespace;
}

return symbol is null or INamespaceSymbol { IsGlobalNamespace: true };
}

public static string GetGeneratedPropertyName(this ITypeSymbol type)
{
switch (type)
Expand Down
15 changes: 8 additions & 7 deletions src/TypeShape.SourceGenerator/Parser/Parser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,21 +145,21 @@ private ImmutableEquatableArray<TypeDeclarationModel> IncludeTypesFromGenerateSh
{
if (ctx.TypeSymbol.IsGenericTypeDefinition())
{
ReportDiagnostic(GenericTypeDefinitionsNotSupported, ctx.DeclarationSyntax.GetLocation(), ctx.TypeSymbol.ToDisplayString());
ReportDiagnostic(GenericTypeDefinitionsNotSupported, ctx.Declarations.First().Syntax.GetLocation(), ctx.TypeSymbol.ToDisplayString());
continue;
}

TypeDataModelGenerationStatus generationStatus = IncludeType(ctx.TypeSymbol);

if (generationStatus is TypeDataModelGenerationStatus.UnsupportedType)
{
ReportDiagnostic(TypeNotSupported, ctx.DeclarationSyntax.GetLocation(), ctx.TypeSymbol.ToDisplayString());
ReportDiagnostic(TypeNotSupported, ctx.Declarations.First().Syntax.GetLocation(), ctx.TypeSymbol.ToDisplayString());
continue;
}

if (generationStatus is TypeDataModelGenerationStatus.InaccessibleType)
{
ReportDiagnostic(TypeNotAccessible, ctx.DeclarationSyntax.GetLocation(), ctx.TypeSymbol.ToDisplayString());
ReportDiagnostic(TypeNotAccessible, ctx.Declarations.First().Syntax.GetLocation(), ctx.TypeSymbol.ToDisplayString());
continue;
}

Expand All @@ -183,20 +183,21 @@ private static TypeId CreateTypeId(ITypeSymbol type)

private TypeDeclarationModel CreateTypeDeclaration(TypeWithAttributeDeclarationContext context, TypeId typeId)
{
string typeDeclarationHeader = FormatTypeDeclarationHeader(context.DeclarationSyntax, context.TypeSymbol, CancellationToken, out bool isPartialHierarchy);
(BaseTypeDeclarationSyntax? declarationSyntax, SemanticModel? semanticModel) = context.Declarations.First();
string typeDeclarationHeader = FormatTypeDeclarationHeader(declarationSyntax, context.TypeSymbol, CancellationToken, out bool isPartialHierarchy);

Stack<string>? parentStack = null;
for (SyntaxNode? parentNode = context.DeclarationSyntax.Parent; parentNode is BaseTypeDeclarationSyntax parentType; parentNode = parentNode.Parent)
for (SyntaxNode? parentNode = declarationSyntax.Parent; parentNode is BaseTypeDeclarationSyntax parentType; parentNode = parentNode.Parent)
{
ITypeSymbol parentSymbol = context.SemanticModel.GetDeclaredSymbol(parentType, CancellationToken)!;
ITypeSymbol parentSymbol = semanticModel.GetDeclaredSymbol(parentType, CancellationToken)!;
string parentHeader = FormatTypeDeclarationHeader(parentType, parentSymbol, CancellationToken, out bool isPartialType);
(parentStack ??= new()).Push(parentHeader);
isPartialHierarchy &= isPartialType;
}

if (!isPartialHierarchy)
{
ReportDiagnostic(GeneratedTypeNotPartial, context.DeclarationSyntax.GetLocation(), context.TypeSymbol.ToDisplayString());
ReportDiagnostic(GeneratedTypeNotPartial, declarationSyntax.GetLocation(), context.TypeSymbol.ToDisplayString());
}

return new TypeDeclarationModel
Expand Down
26 changes: 26 additions & 0 deletions tests/TypeShape.SourceGenerator.UnitTests/CompilationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,30 @@ public partial record DerivedClassWithShadowingMember : BaseClassWithShadowingMe
TypeShapeSourceGeneratorResult result = CompilationHelpers.RunTypeShapeSourceGenerator(compilation);
Assert.Empty(result.Diagnostics);
}

[Fact]
public static void MultiplePartialContextDeclarations_NoErrors()
{
Compilation compilation = CompilationHelpers.CreateCompilation("""
using TypeShape;

public static class Test
{
public static void TestMethod()
{
ITypeShape<string> stringShape = MyWitness.Default.String;
ITypeShape<int> intShape = MyWitness.Default.Int32;
}
}

[GenerateShape<int>]
public partial class MyWitness;

[GenerateShape<string>]
public partial class MyWitness;
""");

TypeShapeSourceGeneratorResult result = CompilationHelpers.RunTypeShapeSourceGenerator(compilation);
Assert.Empty(result.Diagnostics);
}
}

0 comments on commit 029bf54

Please sign in to comment.