Add HTTP WebSockets (#3)

This commit is contained in:
Alessandro Proto 2023-01-17 22:18:22 +01:00 committed by GitHub
parent 916662bebd
commit a0ffd67b78
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 302 additions and 33 deletions

View file

@ -2,6 +2,24 @@ local http = require("http")
local event = require("event")
local expect = require("expect").expect
local WebSocketHandle = {}
function WebSocketHandle:close()
self:closeAsync()
local _, id
repeat
_, id = event.pull("websocket_close")
until id == self.requestId
end
function WebSocketHandle:receive()
local _, id, par
repeat
_, id, par = event.pull("websocket_message")
until id == self.requestId
return par
end
function http.request(url, body, headers, options)
expect(1, url, "string")
expect(2, body, "string", "nil")
@ -37,7 +55,38 @@ function http.post(url, body, headers, options)
expect(1, url, "string")
expect(2, body, "string", "nil")
expect(3, headers, "table", "nil")
expect(4, options, "table", "nil")
expect(4, options, "table", "nil")
return http.request(url, body, headers, options)
end
end
local function buildWebsocketHandle(requestId, handle)
handle.requestId = requestId
local metatable = getmetatable(handle) or {}
metatable.__index = WebSocketHandle
setmetatable(handle, metatable)
return handle
end
function http.websocket(url, headers)
expect(1, url, "string")
expect(2, headers, "table", "nil")
if not http.checkURL(url) then
return nil, "Invalid URL"
end
local requestId = http.websocketAsync(url, headers)
local ev, id, par
repeat
ev, id, par = event.pull("websocket_connect", "websocket_failure")
until id == requestId
if ev == "http_failure" then
return nil, par
end
return buildWebsocketHandle(requestId, par)
end

View file

@ -1,8 +0,0 @@
namespace Capy64.Configuration;
class HTTP
{
public bool Enable { get; set; } = true;
public string[] Blacklist { get; set; }
public WebSockets WebSockets { get; set; }
}

View file

@ -1,7 +0,0 @@
namespace Capy64.Configuration;
class WebSockets
{
public bool Enable { get; set; } = true;
public int MaxActiveConnections { get; set; } = 5;
}

View file

@ -0,0 +1,102 @@
using Capy64.LuaRuntime.Libraries;
using KeraLua;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace Capy64.LuaRuntime.Handlers;
public class WebSocketHandle : IHandle
{
private ClientWebSocket _client;
private long _requestId;
private static IGame _game;
public WebSocketHandle(ClientWebSocket client, long requestId, IGame game)
{
_client = client;
_requestId = requestId;
_game = game;
}
private static readonly Dictionary<string, LuaFunction> functions = new()
{
["send"] = L_Send,
["closeAsync"] = L_CloseAsync,
};
public void Push(Lua L, bool newTable = true)
{
if (newTable)
L.NewTable();
// metatable
L.NewTable();
L.PushString("__close");
L.PushCFunction(L_CloseAsync);
L.SetTable(-3);
L.PushString("__gc");
L.PushCFunction(L_CloseAsync);
L.SetTable(-3);
L.SetMetaTable(-2);
foreach (var pair in functions)
{
L.PushString(pair.Key);
L.PushCFunction(pair.Value);
L.SetTable(-3);
}
L.PushString("_handle");
L.PushObject(this);
L.SetTable(-3);
}
private static WebSocketHandle GetHandle(Lua L, bool gc = true)
{
L.CheckType(1, LuaType.Table);
L.PushString("_handle");
L.GetTable(1);
return L.ToObject<WebSocketHandle>(-1, gc);
}
private static int L_Send(IntPtr state)
{
var L = Lua.FromIntPtr(state);
var data = L.CheckBuffer(2);
var h = GetHandle(L, false);
if (h is null || h._client.State == WebSocketState.Closed)
L.Error("connection is closed");
h._client.SendAsync(data, WebSocketMessageType.Text, true, CancellationToken.None);
return 0;
}
private static int L_CloseAsync(IntPtr state)
{
var L = Lua.FromIntPtr(state);
var h = GetHandle(L, true);
if (h is null || h._client.State == WebSocketState.Closed)
return 0;
h._client.CloseAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None)
.ContinueWith(async task =>
{
await task;
_game.LuaRuntime.PushEvent("websocket_close", h._requestId);
});
HTTP.WebSocketConnections.Remove(h);
return 0;
}
}

View file

@ -7,18 +7,29 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
namespace Capy64.LuaRuntime.Libraries;
#nullable enable
public class HTTP : IPlugin
{
private static IGame _game;
private static HttpClient _client;
private static long RequestId;
private static HttpClient _httpClient;
private static long _requestId;
public static readonly HashSet<WebSocketHandle> WebSocketConnections = new();
private readonly IConfiguration _configuration;
public static readonly string UserAgent = $"Capy64/{Capy64.Version}";
private static IConfiguration _configuration;
private readonly LuaRegister[] HttpLib = new LuaRegister[]
{
new()
{
name = "checkURL",
function = L_CheckUrl,
},
new()
{
name = "requestAsync",
@ -26,17 +37,17 @@ public class HTTP : IPlugin
},
new()
{
name = "checkURL",
function = L_CheckUrl,
name = "websocketAsync",
function = L_WebsocketAsync,
},
new(),
};
public HTTP(IGame game, IConfiguration configuration)
{
_game = game;
RequestId = 0;
_client = new();
_client.DefaultRequestHeaders.Add("User-Agent", $"Capy64/{Capy64.Version}");
_requestId = 0;
_httpClient = new();
_httpClient.DefaultRequestHeaders.Add("User-Agent", UserAgent);
_configuration = configuration;
}
@ -53,10 +64,29 @@ public class HTTP : IPlugin
return 1;
}
public static bool TryGetUri(string url, out Uri? uri)
private static readonly string[] _allowedSchemes = new[]
{
return (Uri.TryCreate(url, UriKind.Absolute, out uri)
&& uri?.Scheme == Uri.UriSchemeHttp) || uri?.Scheme == Uri.UriSchemeHttps;
Uri.UriSchemeHttp,
Uri.UriSchemeHttps,
Uri.UriSchemeWs,
Uri.UriSchemeWss,
};
public static bool TryGetUri(string url, out Uri uri)
{
return Uri.TryCreate(url, UriKind.Absolute, out uri!) && _allowedSchemes.Contains(uri.Scheme);
}
private static int L_CheckUrl(IntPtr state)
{
var L = Lua.FromIntPtr(state);
var url = L.CheckString(1);
var isValid = TryGetUri(url, out _);
L.PushBoolean(isValid);
return 1;
}
private static int L_Request(IntPtr state)
@ -151,9 +181,9 @@ public class HTTP : IPlugin
? new HttpMethod((string)value)
: request.Content is not null ? HttpMethod.Post : HttpMethod.Get;
var requestId = RequestId++;
var requestId = _requestId++;
var reqTask = _client.SendAsync(request);
var reqTask = _httpClient.SendAsync(request);
reqTask.ContinueWith(async (task) =>
{
@ -218,15 +248,118 @@ public class HTTP : IPlugin
return 1;
}
private static int L_CheckUrl(IntPtr state)
private static int L_WebsocketAsync(IntPtr state)
{
var L = Lua.FromIntPtr(state);
var wsSettings = _configuration.GetSection("HTTP:WebSockets");
if (!wsSettings.GetValue<bool>("Enable"))
{
L.Error("WebSockets are disabled");
return 0;
}
if (WebSocketConnections.Count >= wsSettings.GetValue<int>("MaxActiveConnections"))
{
L.Error("Max connections reached");
return 0;
}
var url = L.CheckString(1);
if (!TryGetUri(url, out var uri))
{
L.ArgumentError(1, "invalid request url");
return 0;
}
var isValid = TryGetUri(url, out _);
var requestId = _requestId++;
L.PushBoolean(isValid);
var wsClient = new ClientWebSocket();
wsClient.Options.SetRequestHeader("User-Agent", UserAgent);
if (L.IsTable(2)) // headers
{
L.PushCopy(2);
L.PushNil();
while (L.Next(-2))
{
L.PushCopy(-2);
var k = L.CheckString(-1);
if (L.IsStringOrNumber(-2))
{
var v = L.ToString(-2);
wsClient.Options.SetRequestHeader(k, v);
}
else if (L.IsNil(-2))
{
wsClient.Options.SetRequestHeader(k, null);
}
else
{
L.ArgumentError(3, "string, number or nil expected, got " + L.TypeName(L.Type(-2)) + " in field " + k);
}
L.Pop(2);
}
L.Pop(1);
}
var connectTask = wsClient.ConnectAsync(uri, CancellationToken.None);
connectTask.ContinueWith(async task =>
{
if (task.IsFaulted || task.IsCanceled)
{
_game.LuaRuntime.PushEvent("websocket_failure", requestId, task.Exception?.Message);
return;
}
await task;
var handle = new WebSocketHandle(wsClient, requestId, _game);
WebSocketConnections.Add(handle);
_game.LuaRuntime.PushEvent("websocket_connect", L =>
{
L.PushInteger(requestId);
handle.Push(L, true);
return 2;
});
var buffer = new byte[4096];
var builder = new StringBuilder();
while (wsClient.State == WebSocketState.Open)
{
var result = await wsClient.ReceiveAsync(buffer, CancellationToken.None);
if (result.MessageType == WebSocketMessageType.Close)
{
await wsClient.CloseAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None);
_game.LuaRuntime.PushEvent("websocket_close", requestId);
return;
}
else
{
var data = Encoding.ASCII.GetString(buffer, 0, result.Count);
builder.Append(data);
}
if (result.EndOfMessage)
{
_game.LuaRuntime.PushEvent("websocket_message", requestId, builder.ToString());
builder.Clear();
}
}
});
L.PushInteger(requestId);
return 1;
}