diff --git a/DiscordOAuth.cs b/DiscordOAuth.cs index 88a51e7..736ab7f 100644 --- a/DiscordOAuth.cs +++ b/DiscordOAuth.cs @@ -21,6 +21,7 @@ public class DiscordOAuth private ScopesBuilder Scopes { get; set; } private string? AccessToken { get; set; } + public string State { get; } public static void Configure(ulong clientId, string clientSecret, string? botToken = null) { @@ -31,21 +32,22 @@ public class DiscordOAuth private readonly HttpClient _httpClient = new HttpClient(); - public DiscordOAuth(string redirectUri, ScopesBuilder scopes, bool prompt = true) + public DiscordOAuth(string redirectUri, ScopesBuilder scopes, string state, bool prompt = true) { RedirectUri = redirectUri; Scopes = scopes; Prompt = prompt; + State = state; } - public string GetAuthorizationUrl(string state) + public string GetAuthorizationUrl() { NameValueCollection query = HttpUtility.ParseQueryString(string.Empty); query["client_id"] = ClientId.ToString(); query["redirect_uri"] = RedirectUri; query["response_type"] = "code"; query["scope"] = Scopes.ToString(); - query["state"] = state; + query["state"] = State; query["prompt"] = Prompt ? "consent" : "none"; var uriBuilder = new UriBuilder("https://discord.com/api/oauth2/authorize") @@ -56,19 +58,13 @@ public class DiscordOAuth return uriBuilder.ToString(); } - public static bool TryGetCode(HttpRequest request, string? state, out string? code) + public static bool TryGetCode(HttpRequest request, out string? code) { code = null; if (request.Query.TryGetValue("code", out StringValues codeValues)) { - if (request.Query.TryGetValue("state", out StringValues stateValues)) - { - if (stateValues.FirstOrDefault() == state) - { - code = codeValues; - return true; - } - } + code = codeValues; + return true; } return false; @@ -76,10 +72,7 @@ public class DiscordOAuth public static bool TryGetCode(HttpContext context, out string? code) { - var state = context.Session.TryGetValue("state", out byte[] stateBytes) - ? Encoding.UTF8.GetString(stateBytes) - : null; - var b = TryGetCode(context.Request, state, out var a); + var b = TryGetCode(context.Request, out var a); code = a; return b; } @@ -103,6 +96,11 @@ public class DiscordOAuth return authToken; } + public bool ValidateState(string state) + { + return State == state; + } + private async Task GetInformationAsync(string accessToken, string endpoint) where T : class { _httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", accessToken); @@ -118,9 +116,6 @@ public class DiscordOAuth private async Task GetInformationAsync(HttpContext context, string endpoint) where T : class { - var state = context.Session.TryGetValue("state", out byte[] stateBytes) - ? Encoding.UTF8.GetString(stateBytes) - : string.Empty; if (AccessToken is null) { if (!TryGetCode(context, out var code)) return null; @@ -186,6 +181,7 @@ public class DiscordOAuth public async Task JoinGuildAsync(string accessToken, ulong userId, GuildOptions options) { + if (BotToken is null) throw new InvalidOperationException("Bot token is not set"); var request = new HttpRequestMessage(HttpMethod.Put, $"https://discord.com/api/guilds/{options.GuildId}/members/{userId}"); @@ -208,9 +204,6 @@ public class DiscordOAuth public async Task JoinGuildAsync(HttpContext context, GuildOptions options) { - string state = context.Session.TryGetValue("state", out byte[] stateBytes) - ? Encoding.UTF8.GetString(stateBytes) - : string.Empty; if (AccessToken is null) { if (!TryGetCode(context, out var code)) return false; diff --git a/x3rt.DiscordOAuth2.csproj b/x3rt.DiscordOAuth2.csproj index 5d6b821..f3dc5bf 100644 --- a/x3rt.DiscordOAuth2.csproj +++ b/x3rt.DiscordOAuth2.csproj @@ -17,7 +17,7 @@ GIT Discord-OAuth2;Discord-OAuth-2;Discord-OAuth;DiscordOAuth;Discord;OAuth;OAuth-2;OAuth2 true - 1.0.2 + 1.0.3