diff --git a/Capy64/Assets/Lua/boot/04_http.lua b/Capy64/Assets/Lua/boot/04_http.lua index 020a900..faffa58 100644 --- a/Capy64/Assets/Lua/boot/04_http.lua +++ b/Capy64/Assets/Lua/boot/04_http.lua @@ -2,6 +2,24 @@ local http = require("http") local event = require("event") local expect = require("expect").expect +local WebSocketHandle = {} +function WebSocketHandle:close() + self:closeAsync() + local _, id + repeat + _, id = event.pull("websocket_close") + until id == self.requestId +end + +function WebSocketHandle:receive() + local _, id, par + repeat + _, id, par = event.pull("websocket_message") + until id == self.requestId + + return par +end + function http.request(url, body, headers, options) expect(1, url, "string") expect(2, body, "string", "nil") @@ -37,7 +55,38 @@ function http.post(url, body, headers, options) expect(1, url, "string") expect(2, body, "string", "nil") expect(3, headers, "table", "nil") - expect(4, options, "table", "nil") + expect(4, options, "table", "nil") return http.request(url, body, headers, options) -end \ No newline at end of file +end + +local function buildWebsocketHandle(requestId, handle) + handle.requestId = requestId + local metatable = getmetatable(handle) or {} + metatable.__index = WebSocketHandle + + setmetatable(handle, metatable) + + return handle +end + +function http.websocket(url, headers) + expect(1, url, "string") + expect(2, headers, "table", "nil") + + if not http.checkURL(url) then + return nil, "Invalid URL" + end + + local requestId = http.websocketAsync(url, headers) + local ev, id, par + repeat + ev, id, par = event.pull("websocket_connect", "websocket_failure") + until id == requestId + + if ev == "http_failure" then + return nil, par + end + + return buildWebsocketHandle(requestId, par) +end diff --git a/Capy64/Configuration/HTTP.cs b/Capy64/Configuration/HTTP.cs deleted file mode 100644 index 6e0e1d7..0000000 --- a/Capy64/Configuration/HTTP.cs +++ /dev/null @@ -1,8 +0,0 @@ -namespace Capy64.Configuration; - -class HTTP -{ - public bool Enable { get; set; } = true; - public string[] Blacklist { get; set; } - public WebSockets WebSockets { get; set; } -} diff --git a/Capy64/Configuration/WebSockets.cs b/Capy64/Configuration/WebSockets.cs deleted file mode 100644 index ab112ff..0000000 --- a/Capy64/Configuration/WebSockets.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Capy64.Configuration; - -class WebSockets -{ - public bool Enable { get; set; } = true; - public int MaxActiveConnections { get; set; } = 5; -} diff --git a/Capy64/LuaRuntime/Handlers/WebSocketHandle.cs b/Capy64/LuaRuntime/Handlers/WebSocketHandle.cs new file mode 100644 index 0000000..a5ba4d3 --- /dev/null +++ b/Capy64/LuaRuntime/Handlers/WebSocketHandle.cs @@ -0,0 +1,102 @@ +using Capy64.LuaRuntime.Libraries; +using KeraLua; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Capy64.LuaRuntime.Handlers; + +public class WebSocketHandle : IHandle +{ + private ClientWebSocket _client; + private long _requestId; + private static IGame _game; + public WebSocketHandle(ClientWebSocket client, long requestId, IGame game) + { + _client = client; + _requestId = requestId; + _game = game; + } + + private static readonly Dictionary functions = new() + { + ["send"] = L_Send, + ["closeAsync"] = L_CloseAsync, + }; + + public void Push(Lua L, bool newTable = true) + { + if (newTable) + L.NewTable(); + + // metatable + L.NewTable(); + L.PushString("__close"); + L.PushCFunction(L_CloseAsync); + L.SetTable(-3); + L.PushString("__gc"); + L.PushCFunction(L_CloseAsync); + L.SetTable(-3); + L.SetMetaTable(-2); + + foreach (var pair in functions) + { + L.PushString(pair.Key); + L.PushCFunction(pair.Value); + L.SetTable(-3); + } + + L.PushString("_handle"); + L.PushObject(this); + L.SetTable(-3); + } + + private static WebSocketHandle GetHandle(Lua L, bool gc = true) + { + L.CheckType(1, LuaType.Table); + L.PushString("_handle"); + L.GetTable(1); + return L.ToObject(-1, gc); + } + + private static int L_Send(IntPtr state) + { + var L = Lua.FromIntPtr(state); + + var data = L.CheckBuffer(2); + + var h = GetHandle(L, false); + + if (h is null || h._client.State == WebSocketState.Closed) + L.Error("connection is closed"); + + h._client.SendAsync(data, WebSocketMessageType.Text, true, CancellationToken.None); + + return 0; + } + + private static int L_CloseAsync(IntPtr state) + { + var L = Lua.FromIntPtr(state); + + var h = GetHandle(L, true); + + if (h is null || h._client.State == WebSocketState.Closed) + return 0; + + h._client.CloseAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None) + .ContinueWith(async task => + { + await task; + _game.LuaRuntime.PushEvent("websocket_close", h._requestId); + }); + + HTTP.WebSocketConnections.Remove(h); + + return 0; + } +} diff --git a/Capy64/LuaRuntime/Libraries/HTTP.cs b/Capy64/LuaRuntime/Libraries/HTTP.cs index 323fa64..5a73e28 100644 --- a/Capy64/LuaRuntime/Libraries/HTTP.cs +++ b/Capy64/LuaRuntime/Libraries/HTTP.cs @@ -7,18 +7,29 @@ using System; using System.Collections.Generic; using System.Linq; using System.Net.Http; +using System.Net.WebSockets; +using System.Text; +using System.Threading; namespace Capy64.LuaRuntime.Libraries; #nullable enable public class HTTP : IPlugin { private static IGame _game; - private static HttpClient _client; - private static long RequestId; + private static HttpClient _httpClient; + private static long _requestId; + public static readonly HashSet WebSocketConnections = new(); - private readonly IConfiguration _configuration; + public static readonly string UserAgent = $"Capy64/{Capy64.Version}"; + + private static IConfiguration _configuration; private readonly LuaRegister[] HttpLib = new LuaRegister[] { + new() + { + name = "checkURL", + function = L_CheckUrl, + }, new() { name = "requestAsync", @@ -26,17 +37,17 @@ public class HTTP : IPlugin }, new() { - name = "checkURL", - function = L_CheckUrl, + name = "websocketAsync", + function = L_WebsocketAsync, }, new(), }; public HTTP(IGame game, IConfiguration configuration) { _game = game; - RequestId = 0; - _client = new(); - _client.DefaultRequestHeaders.Add("User-Agent", $"Capy64/{Capy64.Version}"); + _requestId = 0; + _httpClient = new(); + _httpClient.DefaultRequestHeaders.Add("User-Agent", UserAgent); _configuration = configuration; } @@ -53,10 +64,29 @@ public class HTTP : IPlugin return 1; } - public static bool TryGetUri(string url, out Uri? uri) + private static readonly string[] _allowedSchemes = new[] { - return (Uri.TryCreate(url, UriKind.Absolute, out uri) - && uri?.Scheme == Uri.UriSchemeHttp) || uri?.Scheme == Uri.UriSchemeHttps; + Uri.UriSchemeHttp, + Uri.UriSchemeHttps, + Uri.UriSchemeWs, + Uri.UriSchemeWss, + }; + public static bool TryGetUri(string url, out Uri uri) + { + return Uri.TryCreate(url, UriKind.Absolute, out uri!) && _allowedSchemes.Contains(uri.Scheme); + } + + private static int L_CheckUrl(IntPtr state) + { + var L = Lua.FromIntPtr(state); + + var url = L.CheckString(1); + + var isValid = TryGetUri(url, out _); + + L.PushBoolean(isValid); + + return 1; } private static int L_Request(IntPtr state) @@ -151,9 +181,9 @@ public class HTTP : IPlugin ? new HttpMethod((string)value) : request.Content is not null ? HttpMethod.Post : HttpMethod.Get; - var requestId = RequestId++; + var requestId = _requestId++; - var reqTask = _client.SendAsync(request); + var reqTask = _httpClient.SendAsync(request); reqTask.ContinueWith(async (task) => { @@ -218,15 +248,118 @@ public class HTTP : IPlugin return 1; } - private static int L_CheckUrl(IntPtr state) + private static int L_WebsocketAsync(IntPtr state) { var L = Lua.FromIntPtr(state); + var wsSettings = _configuration.GetSection("HTTP:WebSockets"); + + if (!wsSettings.GetValue("Enable")) + { + L.Error("WebSockets are disabled"); + return 0; + } + + if (WebSocketConnections.Count >= wsSettings.GetValue("MaxActiveConnections")) + { + L.Error("Max connections reached"); + return 0; + } + var url = L.CheckString(1); + if (!TryGetUri(url, out var uri)) + { + L.ArgumentError(1, "invalid request url"); + return 0; + } - var isValid = TryGetUri(url, out _); + var requestId = _requestId++; - L.PushBoolean(isValid); + var wsClient = new ClientWebSocket(); + + wsClient.Options.SetRequestHeader("User-Agent", UserAgent); + + if (L.IsTable(2)) // headers + { + L.PushCopy(2); + L.PushNil(); + + while (L.Next(-2)) + { + L.PushCopy(-2); + + var k = L.CheckString(-1); + if (L.IsStringOrNumber(-2)) + { + var v = L.ToString(-2); + + wsClient.Options.SetRequestHeader(k, v); + } + else if (L.IsNil(-2)) + { + wsClient.Options.SetRequestHeader(k, null); + } + else + { + L.ArgumentError(3, "string, number or nil expected, got " + L.TypeName(L.Type(-2)) + " in field " + k); + } + + L.Pop(2); + } + + L.Pop(1); + } + + + var connectTask = wsClient.ConnectAsync(uri, CancellationToken.None); + connectTask.ContinueWith(async task => + { + if (task.IsFaulted || task.IsCanceled) + { + _game.LuaRuntime.PushEvent("websocket_failure", requestId, task.Exception?.Message); + return; + } + + await task; + + var handle = new WebSocketHandle(wsClient, requestId, _game); + WebSocketConnections.Add(handle); + + _game.LuaRuntime.PushEvent("websocket_connect", L => + { + L.PushInteger(requestId); + + handle.Push(L, true); + + return 2; + }); + + var buffer = new byte[4096]; + var builder = new StringBuilder(); + while (wsClient.State == WebSocketState.Open) + { + var result = await wsClient.ReceiveAsync(buffer, CancellationToken.None); + if (result.MessageType == WebSocketMessageType.Close) + { + await wsClient.CloseAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None); + _game.LuaRuntime.PushEvent("websocket_close", requestId); + return; + } + else + { + var data = Encoding.ASCII.GetString(buffer, 0, result.Count); + builder.Append(data); + } + + if (result.EndOfMessage) + { + _game.LuaRuntime.PushEvent("websocket_message", requestId, builder.ToString()); + builder.Clear(); + } + } + }); + + L.PushInteger(requestId); return 1; }