using HISP.Security;
using HISP.Util;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace HISP.Server.Network
{
    /// <summary>
    /// Transport that performs the server side of a WebSocket opening handshake
    /// and then decodes masked client frames into complete messages, which are
    /// handed to <c>onReceiveCallback</c>.
    /// </summary>
    public class WebSocket : Transport
    {
        // GUID appended to the client's Sec-WebSocket-Key before hashing (RFC 6455, section 1.3).
        private const string WEBSOCKET_SEED = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

        // Frame opcodes handled by this transport.
        private const byte WEBASSEMBLY_CONTINUE = 0x0;
        private const byte WEBASSEMBLY_TEXT = 0x1;

        // Values of the 7-bit length field that announce an extended length.
        private const byte WEBASSEMBLY_LENGTH_INT16 = 0x7E;
        private const byte WEBASSEMBLY_LENGTH_INT64 = 0x7F;

        // Size in bytes of the masking key that follows the length in a masked frame.
        private const int MASK_KEY_LENGTH = 4;

        // Payload of the message currently being assembled (may span several frames).
        private List<byte> currentMessage = new List<byte>();

        // True once the HTTP upgrade request has been answered.
        private bool handshakeDone = false;

        /// <summary>
        /// Extracts the header fields from a raw HTTP request.
        /// </summary>
        /// <param name="httpResponse">Raw HTTP text received from the client.</param>
        /// <returns>
        /// Header names (trimmed, lower-cased) mapped to their trimmed values.
        /// Lines starting with "GET" and lines without a colon are skipped.
        /// When a header name repeats, the last value wins.
        /// </returns>
        private Dictionary<string, string> parseHttpHeaders(string httpResponse)
        {
            Dictionary<string, string> httpHeaders = new Dictionary<string, string>();
            string[] lines = httpResponse.Replace("\r", "").Split("\n");

            foreach (string line in lines)
            {
                if (line.StartsWith("GET", StringComparison.Ordinal))
                {
                    continue;
                }

                // Split on the FIRST colon only so values such as "localhost:8080" stay intact.
                int separator = line.IndexOf(':');
                if (separator < 0)
                {
                    continue;
                }

                string name = line.Substring(0, separator).Trim().ToLowerInvariant();
                string value = line.Substring(separator + 1).Trim();

                // Indexer assignment instead of Add: a repeated header must not throw.
                httpHeaders[name] = value;
            }

            return httpHeaders;
        }

        /// <summary>
        /// Computes the Sec-WebSocket-Accept value for a client key: the key is
        /// concatenated with <see cref="WEBSOCKET_SEED"/>, digested with
        /// <c>Authentication.Sha1Digest</c> and Base64 encoded.
        /// </summary>
        /// <param name="webSocketKey">Value of the client's Sec-WebSocket-Key header.</param>
        /// <returns>The Base64 encoded digest.</returns>
        private string deriveWebsocketSecKey(string webSocketKey)
        {
            byte[] derivedKey = Authentication.Sha1Digest(Encoding.UTF8.GetBytes(webSocketKey.Trim() + WEBSOCKET_SEED.Trim()));
            return Convert.ToBase64String(derivedKey);
        }

        /// <summary>
        /// Builds the UTF-8 encoded "101 Switching Protocols" response.
        /// </summary>
        /// <param name="secWebsocketKey">Value to send in the Sec-WebSocket-Accept header.</param>
        /// <returns>The response bytes, terminated by an empty line.</returns>
        private byte[] createHandshakeResponse(string secWebsocketKey)
        {
            return Encoding.UTF8.GetBytes(String.Join("\r\n", new string[] {
                "HTTP/1.1 101 Switching Protocols",
                "Connection: Upgrade",
                "Upgrade: websocket",
                "Sec-WebSocket-Accept: " + secWebsocketKey,
                "",
                ""
            }));
        }

        /// <summary>
        /// Produces the handshake response for a client upgrade request.
        /// </summary>
        /// <param name="handshakeResponse">Raw HTTP upgrade request text.</param>
        /// <returns>
        /// The response bytes. When the request has no sec-websocket-key header
        /// the accept value is left empty.
        /// </returns>
        private byte[] parseHandshake(string handshakeResponse)
        {
            Dictionary<string, string> headers = parseHttpHeaders(handshakeResponse);

            string webSocketKey;
            if (headers.TryGetValue("sec-websocket-key", out webSocketKey) && webSocketKey != null)
            {
                string acceptKey = deriveWebsocketSecKey(webSocketKey);
                return createHandshakeResponse(acceptKey);
            }

            return createHandshakeResponse("");
        }

        /// <summary>Returns whether <paramref name="data"/> begins with the bytes of "GET".</summary>
        public static bool IsStartOfHandshake(byte[] data)
        {
            return Helper.ByteArrayStartsWith(data, Encoding.UTF8.GetBytes("GET"));
        }

        /// <summary>Returns whether <paramref name="data"/> ends with an empty line (CR LF CR LF).</summary>
        public static bool IsEndOfHandshake(byte[] data)
        {
            return Helper.ByteArrayEndsWith(data, Encoding.UTF8.GetBytes("\r\n\r\n"));
        }

        /// <summary>
        /// Appends newly received bytes to the pending buffer, answers the
        /// opening handshake if it has not been completed yet, and otherwise
        /// decodes every complete frame in the buffer.
        /// </summary>
        /// <remarks>
        /// Only masked text and continuation frames contribute payload. Frames
        /// with any other opcode, and unmasked frames, are consumed and ignored.
        /// Bytes belonging to a frame that has not fully arrived stay buffered
        /// until a later call completes it.
        /// </remarks>
        /// <param name="available">Number of valid bytes at the start of <paramref name="buffer"/>.</param>
        /// <param name="buffer">Buffer holding the bytes just read.</param>
        public override void ProcessReceivedPackets(int available, byte[] buffer)
        {
            for (int i = 0; i < available; i++)
            {
                currentPacket.Add(buffer[i]);
            }

            // Snapshot of everything received so far; all parsing works on this
            // array so frames that span several reads are handled correctly.
            byte[] pending = currentPacket.ToArray();

            if (!handshakeDone)
            {
                if (IsStartOfHandshake(pending) && IsEndOfHandshake(pending))
                {
                    string httpHandshake = Encoding.UTF8.GetString(pending);
                    byte[] handshakeResponse = parseHandshake(httpHandshake);
                    base.Send(handshakeResponse);
                    currentPacket.Clear();
                    handshakeDone = true;
                }

                // Never interpret (partial) HTTP text as WebSocket frames.
                return;
            }

            List<byte[]> completedMessages = new List<byte[]>();
            int position = 0;

            while (pending.Length - position >= 2)
            {
                bool finished = (pending[position] & 0b10000000) != 0;
                int opcode = pending[position] & 0b00001111;
                bool mask = (pending[position + 1] & 0b10000000) != 0;
                ulong messageLength = (ulong)(pending[position + 1] & 0b01111111);
                int offset = position + 2;

                if (messageLength == WEBASSEMBLY_LENGTH_INT16)
                {
                    if (pending.Length < offset + 2)
                    {
                        break; // extended length not received yet
                    }

                    // Extended lengths are big-endian.
                    messageLength = ((ulong)pending[offset] << 8) | pending[offset + 1];
                    offset += 2;
                }
                else if (messageLength == WEBASSEMBLY_LENGTH_INT64)
                {
                    if (pending.Length < offset + 8)
                    {
                        break; // extended length not received yet
                    }

                    messageLength = 0;
                    for (int i = 0; i < 8; i++)
                    {
                        messageLength = (messageLength << 8) | pending[offset + i];
                    }
                    offset += 8;
                }

                if (messageLength > int.MaxValue)
                {
                    // Such a frame can never be buffered; drop all pending state
                    // rather than accumulate data indefinitely.
                    currentMessage.Clear();
                    position = pending.Length;
                    break;
                }

                int payloadLength = (int)messageLength;
                int maskLength = mask ? MASK_KEY_LENGTH : 0;
                long frameEnd = (long)offset + maskLength + payloadLength;

                if (pending.Length < frameEnd)
                {
                    break; // frame incomplete, wait for more data
                }

                bool isDataFrame = opcode == WEBASSEMBLY_TEXT || opcode == WEBASSEMBLY_CONTINUE;
                if (mask && isDataFrame)
                {
                    int payloadStart = offset + MASK_KEY_LENGTH;
                    for (int i = 0; i < payloadLength; i++)
                    {
                        // The mask key occupies the MASK_KEY_LENGTH bytes starting at 'offset'.
                        byte keyByte = pending[offset + (i % MASK_KEY_LENGTH)];
                        currentMessage.Add((byte)(pending[payloadStart + i] ^ keyByte));
                    }

                    if (finished)
                    {
                        completedMessages.Add(currentMessage.ToArray());
                        currentMessage.Clear();
                    }
                }

                position = (int)frameEnd;
            }

            if (position > 0)
            {
                // Drop the consumed frames but keep any trailing partial frame.
                currentPacket.Clear();
                for (int i = position; i < pending.Length; i++)
                {
                    currentPacket.Add(pending[i]);
                }
            }

            // Deliver after the buffer has been updated so a throwing callback
            // cannot cause the same frames to be decoded twice.
            foreach (byte[] message in completedMessages)
            {
                onReceiveCallback(message);
            }
        }

        /// <summary>Display name of this transport.</summary>
        public override string Name
        {
            get
            {
                return "WebSocket";
            }
        }

        /// <summary>
        /// Sending application data is not implemented for this transport.
        /// </summary>
        /// <param name="data">Data that would be sent.</param>
        /// <exception cref="NotImplementedException">Always thrown.</exception>
        public override void Send(byte[] data)
        {
            throw new NotImplementedException();
        }
    }
}