Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/Foundatio.Mediator/MetadataMiddlewareScanner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ public override void VisitNamedType(INamedTypeSymbol symbol)
if (symbol.HasIgnoreAttribute(_compilation))
return;

// Skip internal or private middleware from cross-assembly usage
// Only public middleware can be used across assemblies
if (symbol.DeclaredAccessibility != Accessibility.Public)
return;

// Try to extract middleware info from metadata
var middlewareInfo = ExtractMiddlewareInfo(symbol);
if (middlewareInfo != null)
Expand Down Expand Up @@ -156,6 +161,8 @@ public override void VisitNamedType(INamedTypeSymbol symbol)
BeforeMethod = beforeMethod != null ? CreateMiddlewareMethodInfo(beforeMethod) : null,
AfterMethod = afterMethod != null ? CreateMiddlewareMethodInfo(afterMethod) : null,
FinallyMethod = finallyMethod != null ? CreateMiddlewareMethodInfo(finallyMethod) : null,
DeclaredAccessibility = classSymbol.DeclaredAccessibility,
AssemblyName = classSymbol.ContainingAssembly.Name,
Diagnostics = new EquatableArray<DiagnosticInfo>([]) // No diagnostics for metadata-based
};
}
Expand Down
15 changes: 15 additions & 0 deletions src/Foundatio.Mediator/MiddlewareAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,19 @@ public static bool IsMatch(SyntaxNode node)
if (messageType == null)
return null;

// Validate accessibility - private middleware should be ignored or have [FoundatioIgnore]
if (classSymbol.DeclaredAccessibility == Accessibility.Private && !classSymbol.HasIgnoreAttribute(context.SemanticModel.Compilation))
{
diagnostics.Add(new DiagnosticInfo
{
Identifier = "FMED006",
Title = "Private Middleware Not Allowed",
Message = $"Middleware '{classSymbol.Name}' is private and cannot be used. Either make it internal or public, or mark it with [FoundatioIgnore] if it should not be discovered as middleware.",
Severity = DiagnosticSeverity.Error,
Location = LocationInfo.CreateFrom(classDeclaration)
});
}

int? order = null;

// First check [Middleware(order)] attribute
Expand Down Expand Up @@ -171,6 +184,8 @@ public static bool IsMatch(SyntaxNode node)
FinallyMethod = finallyMethod != null ? CreateMiddlewareMethodInfo(finallyMethod, context.SemanticModel.Compilation) : null,
IsStatic = isStatic,
Order = order,
DeclaredAccessibility = classSymbol.DeclaredAccessibility,
AssemblyName = classSymbol.ContainingAssembly.Name,
Diagnostics = new(diagnostics.ToArray()),
};
}
Expand Down
3 changes: 3 additions & 0 deletions src/Foundatio.Mediator/Models/MiddlewareInfo.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Foundatio.Mediator.Utility;
using Microsoft.CodeAnalysis;

namespace Foundatio.Mediator.Models;

Expand All @@ -13,6 +14,8 @@ internal readonly record struct MiddlewareInfo
public bool IsStatic { get; init; }
public bool IsAsync => BeforeMethod?.IsAsync == true || AfterMethod?.IsAsync == true || FinallyMethod?.IsAsync == true;
public int? Order { get; init; }
public Accessibility DeclaredAccessibility { get; init; }
public string AssemblyName { get; init; }
public EquatableArray<DiagnosticInfo> Diagnostics { get; init; }
}

Expand Down
100 changes: 100 additions & 0 deletions tests/Foundatio.Mediator.Tests/CrossAssemblyMiddlewareTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,106 @@ public void Handle(TestMessage msg, CancellationToken ct) { }
Assert.Contains("ImplicitMiddleware.TimingMiddleware.Finally", wrapper.Source);
}

[Fact]
public void InternalMiddlewareNotDiscoveredFromReferencedAssembly()
{
// Internal middleware in referenced assembly should NOT be discovered
var middlewareSource = """
using Foundatio.Mediator;

[assembly: FoundatioModule]

namespace SharedMiddleware;

[Middleware]
internal static class InternalMiddleware
{
public static void Before(object message) { }
}

[Middleware]
public static class PublicMiddleware
{
public static void After(object message) { }
}
""";

var middlewareCompilation = CreateMiddlewareAssembly(middlewareSource);

var handlerSource = """
using System.Threading;
using Foundatio.Mediator;

public record TestMessage;

public class TestHandler
{
public void Handle(TestMessage msg, CancellationToken ct) { }
}
""";

var (_, _, trees) = RunGenerator(handlerSource, [new MediatorGenerator()], additionalReferences: [middlewareCompilation]);

var wrapper = trees.FirstOrDefault(t => t.HintName.EndsWith("_Handler.g.cs"));

// Internal middleware should NOT be included
Assert.DoesNotContain("InternalMiddleware", wrapper.Source);

// Public middleware should be included
Assert.Contains("SharedMiddleware.PublicMiddleware.After", wrapper.Source);
}

[Fact]
public void PrivateMiddlewareNotDiscoveredFromReferencedAssembly()
{
// Private middleware in referenced assembly should NOT be discovered
var middlewareSource = """
using Foundatio.Mediator;

[assembly: FoundatioModule]

namespace SharedMiddleware;

public class Container
{
private class PrivateMiddleware
{
public static void Before(object message) { }
}
}

[Middleware]
public static class PublicMiddleware
{
public static void After(object message) { }
}
""";

var middlewareCompilation = CreateMiddlewareAssembly(middlewareSource);

var handlerSource = """
using System.Threading;
using Foundatio.Mediator;

public record TestMessage;

public class TestHandler
{
public void Handle(TestMessage msg, CancellationToken ct) { }
}
""";

var (_, _, trees) = RunGenerator(handlerSource, [new MediatorGenerator()], additionalReferences: [middlewareCompilation]);

var wrapper = trees.FirstOrDefault(t => t.HintName.EndsWith("_Handler.g.cs"));

// Private middleware should NOT be included
Assert.DoesNotContain("PrivateMiddleware", wrapper.Source);

// Public middleware should be included
Assert.Contains("SharedMiddleware.PublicMiddleware.After", wrapper.Source);
}

private static MetadataReference CreateMiddlewareAssembly(string source)
{
var parseOptions = new CSharpParseOptions(LanguageVersion.CSharp11);
Expand Down
64 changes: 64 additions & 0 deletions tests/Foundatio.Mediator.Tests/DiagnosticValidationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,68 @@ public static async Task Call<T>(IMediator m, T msg) {
var (_, genDiags, _) = RunGenerator(src, [ Gen ]);
Assert.DoesNotContain(genDiags, d => d.Id == "FMED007");
}

[Fact]
public void FMED006_PrivateMiddlewareNotAllowed()
{
var src = """
using System.Threading;
using Foundatio.Mediator;

public record Msg;
public class MsgHandler { public void Handle(Msg m, CancellationToken ct) { } }

public class Container
{
private class PrivateMiddleware
{
public static void Before(Msg m) { }
}
}
""";

var (_, genDiags, _) = RunGenerator(src, [ Gen ]);
Assert.Contains(genDiags, d => d.Id == "FMED006" && d.GetMessage().Contains("PrivateMiddleware"));
}

[Fact]
public void MiddlewareWithIgnoreAttribute_NoDiagnostic()
{
var src = """
using System.Threading;
using Foundatio.Mediator;

public record Msg;
public class MsgHandler { public void Handle(Msg m, CancellationToken ct) { } }

[FoundatioIgnore]
public class IgnoredMiddleware
{
public static void Before(Msg m) { }
}
""";

var (_, genDiags, _) = RunGenerator(src, [ Gen ]);
Assert.DoesNotContain(genDiags, d => d.Id == "FMED006");
}

[Fact]
public void InternalMiddleware_NoError()
{
var src = """
using System.Threading;
using Foundatio.Mediator;

public record Msg;
public class MsgHandler { public void Handle(Msg m, CancellationToken ct) { } }

internal static class InternalMiddleware
{
public static void Before(Msg m) { }
}
""";

var (_, genDiags, _) = RunGenerator(src, [ Gen ]);
Assert.DoesNotContain(genDiags, d => d.Id == "FMED006");
}
}
Loading