diff --git a/DiscordOAuth.cs b/DiscordOAuth.cs index 9db2274..88a51e7 100644 --- a/DiscordOAuth.cs +++ b/DiscordOAuth.cs @@ -56,13 +56,19 @@ public class DiscordOAuth return uriBuilder.ToString(); } - public static bool TryGetCode(HttpRequest request, out string? code) + public static bool TryGetCode(HttpRequest request, string? state, out string? code) { code = null; if (request.Query.TryGetValue("code", out StringValues codeValues)) { - code = codeValues[0]; - return true; + if (request.Query.TryGetValue("state", out StringValues stateValues)) + { + if (stateValues.FirstOrDefault() == state) + { + code = codeValues; + return true; + } + } } return false; @@ -70,7 +76,10 @@ public class DiscordOAuth public static bool TryGetCode(HttpContext context, out string? code) { - var b = TryGetCode(context.Request, out var a); + var state = context.Session.TryGetValue("state", out byte[] stateBytes) + ? Encoding.UTF8.GetString(stateBytes) + : null; + var b = TryGetCode(context.Request, state, out var a); code = a; return b; } @@ -109,6 +118,9 @@ public class DiscordOAuth private async Task GetInformationAsync(HttpContext context, string endpoint) where T : class { + var state = context.Session.TryGetValue("state", out byte[] stateBytes) + ? Encoding.UTF8.GetString(stateBytes) + : string.Empty; if (AccessToken is null) { if (!TryGetCode(context, out var code)) return null; @@ -196,6 +208,9 @@ public class DiscordOAuth public async Task JoinGuildAsync(HttpContext context, GuildOptions options) { + string state = context.Session.TryGetValue("state", out byte[] stateBytes) + ? Encoding.UTF8.GetString(stateBytes) + : string.Empty; if (AccessToken is null) { if (!TryGetCode(context, out var code)) return false; diff --git a/x3rt.DiscordOAuth2.csproj b/x3rt.DiscordOAuth2.csproj index 5b1008d..94f87dc 100644 --- a/x3rt.DiscordOAuth2.csproj +++ b/x3rt.DiscordOAuth2.csproj @@ -14,7 +14,7 @@ LICENSE https://github.com/x3rt/x3rt.DiscordOAuth2 GIT - 1.0.1 + 1.0.2