Proper State

This commit is contained in:
x3rt
2023-05-03 16:02:20 -06:00
parent c8633a0837
commit 984ee9f3e7
2 changed files with 16 additions and 23 deletions

View File

@@ -21,6 +21,7 @@ public class DiscordOAuth
private ScopesBuilder Scopes { get; set; } private ScopesBuilder Scopes { get; set; }
private string? AccessToken { get; set; } private string? AccessToken { get; set; }
public string State { get; }
public static void Configure(ulong clientId, string clientSecret, string? botToken = null) public static void Configure(ulong clientId, string clientSecret, string? botToken = null)
{ {
@@ -31,21 +32,22 @@ public class DiscordOAuth
private readonly HttpClient _httpClient = new HttpClient(); private readonly HttpClient _httpClient = new HttpClient();
public DiscordOAuth(string redirectUri, ScopesBuilder scopes, bool prompt = true) public DiscordOAuth(string redirectUri, ScopesBuilder scopes, string state, bool prompt = true)
{ {
RedirectUri = redirectUri; RedirectUri = redirectUri;
Scopes = scopes; Scopes = scopes;
Prompt = prompt; Prompt = prompt;
State = state;
} }
public string GetAuthorizationUrl(string state) public string GetAuthorizationUrl()
{ {
NameValueCollection query = HttpUtility.ParseQueryString(string.Empty); NameValueCollection query = HttpUtility.ParseQueryString(string.Empty);
query["client_id"] = ClientId.ToString(); query["client_id"] = ClientId.ToString();
query["redirect_uri"] = RedirectUri; query["redirect_uri"] = RedirectUri;
query["response_type"] = "code"; query["response_type"] = "code";
query["scope"] = Scopes.ToString(); query["scope"] = Scopes.ToString();
query["state"] = state; query["state"] = State;
query["prompt"] = Prompt ? "consent" : "none"; query["prompt"] = Prompt ? "consent" : "none";
var uriBuilder = new UriBuilder("https://discord.com/api/oauth2/authorize") var uriBuilder = new UriBuilder("https://discord.com/api/oauth2/authorize")
@@ -56,19 +58,13 @@ public class DiscordOAuth
return uriBuilder.ToString(); return uriBuilder.ToString();
} }
public static bool TryGetCode(HttpRequest request, string? state, out string? code) public static bool TryGetCode(HttpRequest request, out string? code)
{ {
code = null; code = null;
if (request.Query.TryGetValue("code", out StringValues codeValues)) if (request.Query.TryGetValue("code", out StringValues codeValues))
{ {
if (request.Query.TryGetValue("state", out StringValues stateValues)) code = codeValues;
{ return true;
if (stateValues.FirstOrDefault() == state)
{
code = codeValues;
return true;
}
}
} }
return false; return false;
@@ -76,10 +72,7 @@ public class DiscordOAuth
public static bool TryGetCode(HttpContext context, out string? code) public static bool TryGetCode(HttpContext context, out string? code)
{ {
var state = context.Session.TryGetValue("state", out byte[] stateBytes) var b = TryGetCode(context.Request, out var a);
? Encoding.UTF8.GetString(stateBytes)
: null;
var b = TryGetCode(context.Request, state, out var a);
code = a; code = a;
return b; return b;
} }
@@ -103,6 +96,11 @@ public class DiscordOAuth
return authToken; return authToken;
} }
public bool ValidateState(string state)
{
return State == state;
}
private async Task<T?> GetInformationAsync<T>(string accessToken, string endpoint) where T : class private async Task<T?> GetInformationAsync<T>(string accessToken, string endpoint) where T : class
{ {
_httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", accessToken); _httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", accessToken);
@@ -118,9 +116,6 @@ public class DiscordOAuth
private async Task<T?> GetInformationAsync<T>(HttpContext context, string endpoint) where T : class private async Task<T?> GetInformationAsync<T>(HttpContext context, string endpoint) where T : class
{ {
var state = context.Session.TryGetValue("state", out byte[] stateBytes)
? Encoding.UTF8.GetString(stateBytes)
: string.Empty;
if (AccessToken is null) if (AccessToken is null)
{ {
if (!TryGetCode(context, out var code)) return null; if (!TryGetCode(context, out var code)) return null;
@@ -186,6 +181,7 @@ public class DiscordOAuth
public async Task<bool> JoinGuildAsync(string accessToken, ulong userId, GuildOptions options) public async Task<bool> JoinGuildAsync(string accessToken, ulong userId, GuildOptions options)
{ {
if (BotToken is null) throw new InvalidOperationException("Bot token is not set");
var request = var request =
new HttpRequestMessage(HttpMethod.Put, new HttpRequestMessage(HttpMethod.Put,
$"https://discord.com/api/guilds/{options.GuildId}/members/{userId}"); $"https://discord.com/api/guilds/{options.GuildId}/members/{userId}");
@@ -208,9 +204,6 @@ public class DiscordOAuth
public async Task<bool> JoinGuildAsync(HttpContext context, GuildOptions options) public async Task<bool> JoinGuildAsync(HttpContext context, GuildOptions options)
{ {
string state = context.Session.TryGetValue("state", out byte[] stateBytes)
? Encoding.UTF8.GetString(stateBytes)
: string.Empty;
if (AccessToken is null) if (AccessToken is null)
{ {
if (!TryGetCode(context, out var code)) return false; if (!TryGetCode(context, out var code)) return false;

View File

@@ -17,7 +17,7 @@
<RepositoryType>GIT</RepositoryType> <RepositoryType>GIT</RepositoryType>
<PackageTags>Discord-OAuth2;Discord-OAuth-2;Discord-OAuth;DiscordOAuth;Discord;OAuth;OAuth-2;OAuth2</PackageTags> <PackageTags>Discord-OAuth2;Discord-OAuth-2;Discord-OAuth;DiscordOAuth;Discord;OAuth;OAuth-2;OAuth2</PackageTags>
<Deterministic>true</Deterministic> <Deterministic>true</Deterministic>
<Version>1.0.2</Version> <Version>1.0.3</Version>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>