Partially implement state

This commit is contained in:
x3rt
2023-04-26 21:43:25 -06:00
parent 1283c16cf2
commit 38db7815c8
2 changed files with 20 additions and 5 deletions

View File

@@ -56,13 +56,19 @@ public class DiscordOAuth
return uriBuilder.ToString(); return uriBuilder.ToString();
} }
public static bool TryGetCode(HttpRequest request, out string? code) public static bool TryGetCode(HttpRequest request, string? state, out string? code)
{ {
code = null; code = null;
if (request.Query.TryGetValue("code", out StringValues codeValues)) if (request.Query.TryGetValue("code", out StringValues codeValues))
{ {
code = codeValues[0]; if (request.Query.TryGetValue("state", out StringValues stateValues))
return true; {
if (stateValues.FirstOrDefault() == state)
{
code = codeValues;
return true;
}
}
} }
return false; return false;
@@ -70,7 +76,10 @@ public class DiscordOAuth
public static bool TryGetCode(HttpContext context, out string? code) public static bool TryGetCode(HttpContext context, out string? code)
{ {
var b = TryGetCode(context.Request, out var a); var state = context.Session.TryGetValue("state", out byte[] stateBytes)
? Encoding.UTF8.GetString(stateBytes)
: null;
var b = TryGetCode(context.Request, state, out var a);
code = a; code = a;
return b; return b;
} }
@@ -109,6 +118,9 @@ public class DiscordOAuth
private async Task<T?> GetInformationAsync<T>(HttpContext context, string endpoint) where T : class private async Task<T?> GetInformationAsync<T>(HttpContext context, string endpoint) where T : class
{ {
var state = context.Session.TryGetValue("state", out byte[] stateBytes)
? Encoding.UTF8.GetString(stateBytes)
: string.Empty;
if (AccessToken is null) if (AccessToken is null)
{ {
if (!TryGetCode(context, out var code)) return null; if (!TryGetCode(context, out var code)) return null;
@@ -196,6 +208,9 @@ public class DiscordOAuth
public async Task<bool> JoinGuildAsync(HttpContext context, GuildOptions options) public async Task<bool> JoinGuildAsync(HttpContext context, GuildOptions options)
{ {
string state = context.Session.TryGetValue("state", out byte[] stateBytes)
? Encoding.UTF8.GetString(stateBytes)
: string.Empty;
if (AccessToken is null) if (AccessToken is null)
{ {
if (!TryGetCode(context, out var code)) return false; if (!TryGetCode(context, out var code)) return false;

View File

@@ -14,7 +14,7 @@
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>
<RepositoryUrl>https://github.com/x3rt/x3rt.DiscordOAuth2</RepositoryUrl> <RepositoryUrl>https://github.com/x3rt/x3rt.DiscordOAuth2</RepositoryUrl>
<RepositoryType>GIT</RepositoryType> <RepositoryType>GIT</RepositoryType>
<Version>1.0.1</Version> <Version>1.0.2</Version>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>