Tuesday, 20 September 2016

C# OTP Implementation with TOTP and HOTP

Sample implementation of HOTP and TOTP One Time Passwords (OTP) in C# with .NET Core

This includes an example of basic caching which can easily be tied into an IMemoryCache instance for web usage.

Gist available at https://gist.github.com/BravoTango86/9ebb578fa4df3a0ffed28bd634f8f3c0
/*
 * Copyright (C) 2016 BravoTango86
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

using System;
using System.Collections.Generic;
using System.Security.Cryptography;
using System.Text;
using static OtpAuthenticator;

/// <summary>
/// Generates and verifies one-time passwords, either counter based (HOTP) or time based (TOTP),
/// using an HMAC keyed with a shared secret.
/// </summary>
/// <remarks>
/// The instance owns its <see cref="System.Security.Cryptography.HMAC"/> object and must be disposed.
/// Without a <see cref="CachingProvider"/> nothing prevents a code from being accepted more than once
/// while it is still inside the verification window.
/// </remarks>
public class OtpAuthenticator : IDisposable {

    /// <summary>Optional replay cache consulted before, and updated after, a successful verification.</summary>
    public ICachingProvider CachingProvider { get; set; }

    /// <summary>Number of digits in a generated code (default 6).</summary>
    public int CodeLength { get; set; } = 6;

    /// <summary>Counter used as the HOTP state; set to one past the matched counter after a successful HOTP verification.</summary>
    public long HotpCounter { get; set; } = 0;

    /// <summary>Length of a TOTP time step in seconds (default 30).</summary>
    public int TotpInterval { get; set; } = 30;

    /// <summary>Seconds added to the current UTC Unix time before the TOTP time step is computed.</summary>
    public long TotpTimeOffsetSeconds { get; set; } = 0;

    private OtpType Type;
    // Remembered so GetUri does not have to reverse-engineer it from HMAC.HashName.
    private OtpAlgorithm Algorithm;
    private HMAC HMAC;

    /// <summary>Creates an authenticator of the given type.</summary>
    /// <param name="type">HOTP or TOTP.</param>
    /// <param name="algorithm">HMAC algorithm to use.</param>
    /// <param name="key">Shared secret; when null a random key is generated via <see cref="GenerateKey(OtpAlgorithm)"/>.</param>
    /// <exception cref="InvalidOperationException"><paramref name="algorithm"/> is not a known value.</exception>
    public OtpAuthenticator(OtpType type, OtpAlgorithm algorithm = OtpAlgorithm.SHA1, byte[] key = null) {
        HMAC = GetHMAC(algorithm, key ?? GenerateKey(algorithm));
        Algorithm = algorithm;
        Type = type;
    }

    /// <summary>
    /// Checks a submitted code against every state from <paramref name="back"/> steps behind the
    /// current state to <paramref name="forward"/> steps ahead of it.
    /// </summary>
    /// <param name="code">Code supplied by the user.</param>
    /// <param name="challenge">Optional bytes appended to the state before hashing.</param>
    /// <param name="forward">Number of states ahead of the current one to accept.</param>
    /// <param name="back">Number of states behind the current one to accept.</param>
    /// <returns>True when the code matches one of the states in the window and was not rejected by the cache.</returns>
    public bool VerifyOtp(string code, byte[] challenge = null, int forward = 1, int back = 1) {
        if (string.IsNullOrEmpty(code) || code.Length != CodeLength) return false;
        if (CachingProvider != null && !CachingProvider.ValidateToken(code)) return false;
        long State = GetState(-back);
        List<string> Tried = new List<string>();
        for (int I = 0; I <= (forward + back); I++) {
            string Candidate = GetOtp(state: State + I, challenge: challenge);
            // Record the GENERATED code: every code up to and including the match is cancelled,
            // so an older code from the window cannot be used after a newer one was accepted.
            Tried.Add(Candidate);
            if (Candidate == code) {
                if (CachingProvider != null) CachingProvider.CancelTokens(Type, Tried);
                if (Type == OtpType.HOTP) HotpCounter = State + I + 1;
                return true;
            }
        }
        // NOTE(review): for HOTP a non-zero 'back' re-tests the counter that was just consumed, so
        // without a CachingProvider the previous code is accepted again - confirm this is intended.
        return false;
    }

    /// <summary>Returns the HOTP counter or the current TOTP time step, shifted by <paramref name="offset"/>.</summary>
    private long GetState(int offset) {
        long Current = Type == OtpType.HOTP
            ? HotpCounter
            : (DateTimeOffset.UtcNow.ToUnixTimeSeconds() + TotpTimeOffsetSeconds) / TotpInterval;
        return Current + offset;
    }

    /// <summary>Generates the code for the current state shifted by <paramref name="offset"/> steps.</summary>
    /// <param name="offset">Number of steps relative to the current state.</param>
    /// <param name="challenge">Optional bytes appended to the state before hashing.</param>
    public string GetOtp(int offset = 0, byte[] challenge = null) {
        return GetOtp(GetState(offset), challenge);
    }

    /// <summary>Generates the code for an absolute state value.</summary>
    private string GetOtp(long state, byte[] challenge = null) {
        // The state is hashed as an 8 byte big-endian integer.
        byte[] Input = BitConverter.GetBytes(state);
        if (BitConverter.IsLittleEndian) Array.Reverse(Input);
        if (challenge != null) {
            Array.Resize(ref Input, Input.Length + challenge.Length);
            Buffer.BlockCopy(challenge, 0, Input, 8, challenge.Length);
        }
        byte[] Hash = HMAC.ComputeHash(Input);
        // Dynamic truncation: the low nibble of the last byte selects a 4 byte window, whose top bit is cleared.
        // NOTE(review): with a 16 byte MD5 digest an offset above 12 reads past the end of Hash.
        int Offset = Hash[Hash.Length - 1] & 0xf;
        int Binary = ((Hash[Offset] & 0x7f) << 24) | ((Hash[Offset + 1] & 0xff) << 16) |
                        ((Hash[Offset + 2] & 0xff) << 8) | (Hash[Offset + 3] & 0xff);
        return (Binary % (int)Math.Pow(10, CodeLength)).ToString(new string('0', CodeLength));
    }

    /// <summary>
    /// Returns the code for state 0, which depends only on the key, algorithm and code length.
    /// </summary>
    public string GetIntegrityValue() {
        // The named argument selects the absolute-state overload; a bare GetOtp(0) binds to the
        // int 'offset' overload and would return the current code instead.
        return GetOtp(state: 0L);
    }

    /// <summary>Builds an otpauth:// URI describing the given parameters.</summary>
    /// <param name="type">HOTP or TOTP; written lower-cased as the URI host.</param>
    /// <param name="key">Shared secret, written Base32 encoded.</param>
    /// <param name="accountName">Account label.</param>
    /// <param name="issuer">Optional issuer, used as label prefix and as the issuer parameter.</param>
    /// <param name="algorithm">Written as the algorithm parameter.</param>
    /// <param name="codeLength">Written as the digits parameter.</param>
    /// <param name="counter">Written as the counter parameter for HOTP.</param>
    /// <param name="period">Written as the period parameter for TOTP.</param>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> or <paramref name="accountName"/> is null.</exception>
    public static string GetUri(OtpType type, byte[] key, string accountName, string issuer = "", OtpAlgorithm algorithm = OtpAlgorithm.SHA1,
                                        int codeLength = 6, long counter = 0, int period = 30) {
        if (key == null) throw new ArgumentNullException(nameof(key));
        // EscapeDataString also escapes reserved characters (&, ?, :, #) that would otherwise
        // corrupt the label or the query string.
        string EscapedAccount = Uri.EscapeDataString(accountName);
        StringBuilder SB = new StringBuilder();
        SB.AppendFormat("otpauth://{0}/", type.ToString().ToLowerInvariant());
        if (!string.IsNullOrEmpty(issuer)) {
            SB.AppendFormat("{0}:{1}?issuer={0}&", Uri.EscapeDataString(issuer), EscapedAccount);
        } else {
            SB.AppendFormat("{0}?", EscapedAccount);
        }
        SB.AppendFormat("secret={0}&algorithm={1}&digits={2}&", Base32.Encode(key), algorithm, codeLength);
        if (type == OtpType.HOTP) SB.AppendFormat("counter={0}", counter);
        else SB.AppendFormat("period={0}", period);
        return SB.ToString();
    }

    /// <summary>Builds the otpauth:// URI for this instance's key and settings.</summary>
    /// <param name="accountName">Account label.</param>
    /// <param name="issuer">Optional issuer.</param>
    public string GetUri(string accountName, string issuer = "") {
        return GetUri(Type, HMAC.Key, accountName, issuer, Algorithm, CodeLength, HotpCounter, TotpInterval);
    }

    /// <summary>Returns the otpauth:// URI using "OtpGenerator" as the account name. The URI contains the secret.</summary>
    public override string ToString() {
        return GetUri("OtpGenerator");
    }

    /// <summary>Generates a random key whose length is chosen per algorithm.</summary>
    public static byte[] GenerateKey(OtpAlgorithm algorithm) {
        return GenerateKey(GetHashLength(algorithm));
    }

    /// <summary>Generates <paramref name="length"/> random bytes from a cryptographic random number generator.</summary>
    public static byte[] GenerateKey(int length) {
        using (RandomNumberGenerator RNG = RandomNumberGenerator.Create()) {
            byte[] Output = new byte[length];
            RNG.GetBytes(Output);
            return Output;
        }
    }

    /// <summary>Disposes the underlying HMAC instance.</summary>
    public void Dispose() {
        HMAC.Dispose();
    }

    /// <summary>Creates the keyed HMAC implementation for <paramref name="algorithm"/>.</summary>
    private static HMAC GetHMAC(OtpAlgorithm algorithm, byte[] key) {
        switch (algorithm) {
            case OtpAlgorithm.MD5: return new HMACMD5(key);
            case OtpAlgorithm.SHA1: return new HMACSHA1(key);
            case OtpAlgorithm.SHA256: return new HMACSHA256(key);
            case OtpAlgorithm.SHA512: return new HMACSHA512(key);
        }
        throw new InvalidOperationException();
    }

    /// <summary>Returns the number of key bytes generated for <paramref name="algorithm"/>.</summary>
    private static int GetHashLength(OtpAlgorithm algorithm) {
        switch (algorithm) {
            // NOTE(review): an MD5 digest is 16 bytes; 32 is kept so generated key sizes do not change.
            case OtpAlgorithm.MD5: return 32;
            case OtpAlgorithm.SHA1: return 20;
            case OtpAlgorithm.SHA256: return 32;
            case OtpAlgorithm.SHA512: return 64;
        }
        throw new InvalidOperationException();
    }

}


/// <summary>
/// Selects how OtpAuthenticator derives the state that is hashed into a code.
/// </summary>
public enum OtpType {
    /// <summary>Counter based: the state is OtpAuthenticator.HotpCounter.</summary>
    HOTP = 0,
    /// <summary>Time based: the state is the UTC Unix time (plus offset) divided by OtpAuthenticator.TotpInterval.</summary>
    TOTP = 1
}

/// <summary>
/// HMAC algorithms supported by OtpAuthenticator. The member names are written verbatim
/// into the "algorithm" parameter of the otpauth URI produced by OtpAuthenticator.GetUri.
/// </summary>
public enum OtpAlgorithm {
    /// <summary>HMAC-MD5.</summary>
    MD5 = 10,
    /// <summary>HMAC-SHA1; the default used by OtpAuthenticator.</summary>
    SHA1 = 1,
    /// <summary>HMAC-SHA256.</summary>
    SHA256 = 2,
    /// <summary>HMAC-SHA512.</summary>
    SHA512 = 3
}

/// <summary>
/// Replay cache used by OtpAuthenticator.VerifyOtp: a submitted code is rejected when
/// <see cref="ValidateToken"/> returns false, and tokens are handed to
/// <see cref="CancelTokens"/> after a successful match.
/// </summary>
public interface ICachingProvider {
    /// <summary>Marks a single token as no longer usable.</summary>
    /// <param name="type">OTP type the token belongs to.</param>
    /// <param name="token">Token to cancel.</param>
    void CancelToken(OtpType type, string token);
    /// <summary>Marks several tokens as no longer usable.</summary>
    /// <param name="type">OTP type the tokens belong to.</param>
    /// <param name="tokens">Tokens to cancel.</param>
    void CancelTokens(OtpType type, IEnumerable<string> tokens);
    /// <summary>Returns true when the token may still be used, false when it must be rejected.</summary>
    /// <param name="token">Token submitted for verification.</param>
    bool ValidateToken(string token);
}

/// <summary>
/// In-memory <see cref="ICachingProvider"/> that remembers every cancelled token for the
/// lifetime of the instance.
/// </summary>
/// <remarks>
/// Tokens are never expired and the OTP type is ignored, so a cancelled value stays rejected
/// for good. The class performs no synchronisation.
/// </remarks>
public class LocalCachingProvider : ICachingProvider {

    // A set gives constant-time lookups and does not accumulate duplicates of the same token.
    private readonly HashSet<string> Used = new HashSet<string>();

    /// <summary>Remembers <paramref name="token"/> as used; <paramref name="type"/> is ignored.</summary>
    public void CancelToken(OtpType type, string token) {
        Used.Add(token);
    }

    /// <summary>Remembers every token in <paramref name="tokens"/> as used; <paramref name="type"/> is ignored.</summary>
    /// <exception cref="System.ArgumentNullException"><paramref name="tokens"/> is null.</exception>
    public void CancelTokens(OtpType type, IEnumerable<string> tokens) {
        Used.UnionWith(tokens);
    }

    /// <summary>Returns true when <paramref name="token"/> has not been cancelled before.</summary>
    public bool ValidateToken(string token) {
        return !Used.Contains(token);
    }
}

Borrowing the barcode generator from before...

/// <summary>
/// Console demo: prints a QR code for a freshly generated TOTP secret, then verifies every
/// line typed on standard input against it. Runs until the process is terminated.
/// </summary>
public static void Main(string[] args) {
    Console.OutputEncoding = Encoding.UTF8;
    Console.WindowWidth = 86;
    Console.WindowHeight = 44;

    // Dark modules are drawn as spaces and light modules as full blocks.
    StringRenderer renderer = new StringRenderer();
    renderer.Block = "  ";
    renderer.Empty = "\u2588\u2588";
    renderer.NewLine = "\n    ";

    EncodingOptions encoding = new EncodingOptions();
    encoding.Height = 0;
    encoding.Width = 0;
    encoding.Margin = 1;

    using (OtpAuthenticator authenticator = new OtpAuthenticator(OtpType.TOTP) { CachingProvider = new LocalCachingProvider() }) {
        string uri = authenticator.GetUri("Test Account", "OTPGenerator");
        string barcode = BarcodeGenerator.Generate(renderer, uri, encoding);
        Console.WriteLine("\n{1}{0}{1}", barcode, renderer.NewLine);
        for (;;) {
            string entered = Console.ReadLine();
            Console.WriteLine(authenticator.VerifyOtp(entered) ? "Code Accepted" : "Code Invalid");
        }
    }
}

...and scanning the generated barcode with Google Authenticator, we can check it works:

Generating 2D Barcodes in C#

Simplified example using ImageProcessorCore and ZXing for .NET Core.

Using StringRenderer will produce a string using the relevant symbols provided and ImageRenderer will return an ImageProcessorCore image.

Gist available at https://gist.github.com/BravoTango86/ca613445c740098e349caf5943b36abb

/*
 * Copyright (C) 2016 BravoTango86
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

using ImageProcessorCore;
using System.Text;
using ZXing;
using ZXing.Common;
using ZXing.Rendering;

/// <summary>
/// Thin wrapper around ZXing's generic barcode writer plus two renderers: one that produces
/// text and one that produces an ImageProcessorCore image.
/// </summary>
public class BarcodeGenerator {

    /// <summary>Encodes <paramref name="content"/> and renders it with <paramref name="renderer"/>.</summary>
    /// <param name="renderer">Renderer that turns the bit matrix into a <typeparamref name="T"/>.</param>
    /// <param name="content">Text to encode.</param>
    /// <param name="options">Encoding options passed to the writer.</param>
    /// <param name="format">Barcode format; QR code unless specified.</param>
    public static T Generate<T>(IBarcodeRenderer<T> renderer, string content, EncodingOptions options,
                                        BarcodeFormat format = BarcodeFormat.QR_CODE) {
        ZXing.BarcodeWriterGeneric<T> writer = new ZXing.BarcodeWriterGeneric<T> {
            Format = format,
            Options = options,
            Renderer = renderer
        };
        return writer.Write(content);
    }

    /// <summary>Renders a bit matrix as text, one string per module.</summary>
    public class StringRenderer : IBarcodeRenderer<string> {

        /// <summary>Text emitted for a set module.</summary>
        public string Block { get; set; }

        /// <summary>Text emitted for an unset module.</summary>
        public string Empty { get; set; }

        /// <summary>Text emitted between rows.</summary>
        public string NewLine { get; set; }

        /// <summary>Renders the matrix without encoding options.</summary>
        public string Render(BitMatrix matrix, BarcodeFormat format, string content) {
            return Render(matrix, format, content, null);
        }

        /// <summary>Renders the matrix row by row; format, content and options are not used.</summary>
        public string Render(BitMatrix matrix, BarcodeFormat format, string content, EncodingOptions options) {
            StringBuilder text = new StringBuilder();
            for (int row = 0; row < matrix.Height; row++) {
                // The separator goes between rows, never before the first one.
                if (row != 0) {
                    text.Append(NewLine);
                }
                for (int column = 0; column < matrix.Width; column++) {
                    if (matrix[column, row]) {
                        text.Append(Block);
                    } else {
                        text.Append(Empty);
                    }
                }
            }
            return text.ToString();
        }
    }

    /// <summary>Renders a bit matrix as an image with one pixel per module.</summary>
    public class ImageRenderer : IBarcodeRenderer<Image> {

        /// <summary>Colour of unset modules (white unless changed).</summary>
        public Color Background { get; set; } = Color.White;

        /// <summary>Colour of set modules (black unless changed).</summary>
        public Color Foreground { get; set; } = Color.Black;

        /// <summary>Renders the matrix without encoding options.</summary>
        public Image Render(BitMatrix matrix, BarcodeFormat format, string content) {
            return Render(matrix, format, content, null);
        }

        /// <summary>Creates an image the size of the matrix and colours each pixel; format, content and options are not used.</summary>
        public Image Render(BitMatrix matrix, BarcodeFormat format, string content, EncodingOptions options) {
            Image rendered = new Image(matrix.Width, matrix.Height);
            using (IPixelAccessor<Color, uint> pixels = rendered.Lock()) {
                for (int row = 0; row < matrix.Height; row++) {
                    for (int column = 0; column < matrix.Width; column++) {
                        Color shade = matrix[column, row] ? Foreground : Background;
                        pixels[column, row] = shade;
                    }
                }
            }
            return rendered;
        }
    }

}

StringRenderer works quite nicely with the console...

/// <summary>
/// Console demo: prints a QR code of the current date and time after every key press until
/// Escape is pressed.
/// </summary>
public static void Main(string[] args) {
    Console.OutputEncoding = Encoding.UTF8;
    Console.WindowWidth = 55;
    Console.WindowHeight = 28;

    // Dark modules are drawn as spaces and light modules as full blocks.
    StringRenderer renderer = new StringRenderer();
    renderer.Block = "  ";
    renderer.Empty = "\u2588\u2588";
    renderer.NewLine = "\n    ";

    EncodingOptions encoding = new EncodingOptions();
    encoding.Height = 0;
    encoding.Width = 0;
    encoding.Margin = 1;

    do {
        string stamp = DateTime.Now.ToString();
        string barcode = Generate(renderer, stamp, encoding);
        Console.WriteLine("\n{2}{0}{2}{1}{2}", barcode, stamp, renderer.NewLine);
    } while (Console.ReadKey().Key != ConsoleKey.Escape);
}

...producing this:

Base32 Encoding and Decoding in C#

Gist available at https://gist.github.com/BravoTango86/2a085185c3b9bd8383a1f956600e515f
/*
 * Derived from https://github.com/google/google-authenticator-android/blob/master/AuthenticatorApp/src/main/java/com/google/android/apps/authenticator/Base32String.java
 * 
 * Copyright (C) 2016 BravoTango86
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

using System;
using System.Collections.Generic;
using System.Text;
using System.Text.RegularExpressions;

/// <summary>
/// Base32 encoder/decoder using the alphabet "A-Z2-7", with optional '=' padding on output.
/// </summary>
public static class Base32 {

    private static readonly char[] DIGITS;
    private static readonly int MASK;
    private static readonly int SHIFT;
    private static readonly Dictionary<char, int> CHAR_MAP = new Dictionary<char, int>();
    private const string SEPARATOR = "-";

    static Base32() {
        DIGITS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567".ToCharArray();
        MASK = DIGITS.Length - 1;
        // Bits carried by one output character (5 for a 32 symbol alphabet).
        SHIFT = numberOfTrailingZeros(DIGITS.Length);
        for (int i = 0; i < DIGITS.Length; i++) CHAR_MAP[DIGITS[i]] = i;
    }

    /// <summary>Counts the zero bits below the lowest set bit of <paramref name="i"/>; 32 when <paramref name="i"/> is 0.</summary>
    private static int numberOfTrailingZeros(int i) {
        // HD, Figure 5-14
        int y;
        if (i == 0) return 32;
        int n = 31;
        y = i << 16; if (y != 0) { n = n - 16; i = y; }
        y = i << 8; if (y != 0) { n = n - 8; i = y; }
        y = i << 4; if (y != 0) { n = n - 4; i = y; }
        y = i << 2; if (y != 0) { n = n - 2; i = y; }
        return n - (int)((uint)(i << 1) >> 31);
    }

    /// <summary>Decodes a Base32 string.</summary>
    /// <param name="encoded">
    /// Text to decode. Whitespace anywhere in the string, '-' separators and trailing '='
    /// padding are ignored, and letters may be in either case.
    /// </param>
    /// <returns>The decoded bytes; leftover bits that do not fill a whole byte are dropped.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="encoded"/> is null.</exception>
    /// <exception cref="FormatException">The input contains a character outside the alphabet.</exception>
    public static byte[] Decode(string encoded) {
        if (encoded == null) throw new ArgumentNullException(nameof(encoded));

        // Remove ALL whitespace (keys are commonly shown in space separated groups) and separators.
        encoded = Regex.Replace(encoded, @"\s+", "").Replace(SEPARATOR, "");

        // Remove padding; the number of padding characters is not validated.
        encoded = encoded.TrimEnd('=');

        // Canonicalize to upper case independently of the current culture, so that
        // e.g. a Turkish locale does not map 'i' to a character outside the alphabet.
        encoded = encoded.ToUpperInvariant();
        if (encoded.Length == 0) {
            return new byte[0];
        }

        int outLength = encoded.Length * SHIFT / 8;
        byte[] result = new byte[outLength];
        int buffer = 0;
        int next = 0;
        int bitsLeft = 0;
        foreach (char c in encoded) {
            int value;
            if (!CHAR_MAP.TryGetValue(c, out value)) {
                throw new DecodingException("Illegal character: " + c);
            }
            buffer = (buffer << SHIFT) | (value & MASK);
            bitsLeft += SHIFT;
            if (bitsLeft >= 8) {
                // Only the low bitsLeft bits of buffer are meaningful; take the top 8 of them.
                result[next++] = (byte)((buffer >> (bitsLeft - 8)) & 0xff);
                bitsLeft -= 8;
            }
        }
        // Leftover bits (fewer than 8) are deliberately ignored.
        return result;
    }


    /// <summary>Encodes bytes as Base32 text.</summary>
    /// <param name="data">Bytes to encode.</param>
    /// <param name="padOutput">When true the result is padded with '=' to a multiple of 8 characters.</param>
    /// <returns>The encoded text, or an empty string for empty input.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="data"/> has 2^28 or more bytes.</exception>
    public static string Encode(byte[] data, bool padOutput = false) {
        if (data == null) throw new ArgumentNullException(nameof(data));
        if (data.Length == 0) {
            return "";
        }

        // SHIFT is the number of bits per output character, so the length of the
        // output is the length of the input multiplied by 8/SHIFT, rounded up.
        if (data.Length >= (1 << 28)) {
            // The computation below would overflow, so don't do it.
            throw new ArgumentOutOfRangeException(nameof(data));
        }

        int outputLength = (data.Length * 8 + SHIFT - 1) / SHIFT;
        StringBuilder result = new StringBuilder(outputLength);

        int buffer = data[0];
        int next = 1;
        int bitsLeft = 8;
        while (bitsLeft > 0 || next < data.Length) {
            if (bitsLeft < SHIFT) {
                if (next < data.Length) {
                    buffer <<= 8;
                    buffer |= (data[next++] & 0xff);
                    bitsLeft += 8;
                } else {
                    // Final partial group: pad with zero bits up to a whole character.
                    int pad = SHIFT - bitsLeft;
                    buffer <<= pad;
                    bitsLeft += pad;
                }
            }
            int index = MASK & (buffer >> (bitsLeft - SHIFT));
            bitsLeft -= SHIFT;
            result.Append(DIGITS[index]);
        }
        if (padOutput) {
            int padding = (8 - (result.Length % 8)) % 8;
            result.Append('=', padding);
        }
        return result.ToString();
    }

    /// <summary>Raised by <see cref="Decode"/> for characters outside the alphabet.</summary>
    /// <remarks>Derives from FormatException so callers have a public type to catch.</remarks>
    private sealed class DecodingException : FormatException {
        public DecodingException(string message) : base(message) {
        }
    }
}

Friday, 16 September 2016

C# SNTP Client based on android.net.SntpClient

I've had major issues with time drift and Time Based One Time Password (TOTP) generation against Google Authenticator.

The android app does something like this to synchronise time, but can only provide an offset down to the nearest second:

/// <summary>
/// Estimates the difference between a web server's clock and the local clock from the Date
/// header of a HEAD response.
/// </summary>
/// <param name="url">Address to send the HEAD request to.</param>
/// <returns>
/// Server time minus local time in whole seconds (truncated towards zero), or 0 when the
/// response carries no Date header.
/// </returns>
public async static Task<long> GetTimeOffsetHTTP(string url = "http://www.google.com") {
    using (HttpClient HC = new HttpClient())
    using (HttpRequestMessage Request = new HttpRequestMessage(HttpMethod.Head, url))
    using (HttpResponseMessage Result = await HC.SendAsync(Request)) {
        DateTimeOffset? ServerDate = Result.Headers.Date;
        if (ServerDate.HasValue) {
            // Compare UTC ticks on both sides. DateTimeOffset.Ticks is the clock time in the
            // value's own offset, so Now.Ticks would add the local time zone to the result.
            return (ServerDate.Value.UtcTicks - DateTimeOffset.UtcNow.UtcTicks) / TimeSpan.TicksPerSecond;
        }
    }
    return 0;
}

The best example of an SNTP client I could find was the one used by android itself, so with a bit of creativity this was born.

Gist available at https://gist.github.com/BravoTango86/2e221d6cac22f7e432c187c941b01648

/*
 * Derived from https://android.googlesource.com/platform/frameworks/base/+/master/core/java/android/net/SntpClient.java
 * 
 * Copyright (C) 2016 BravoTango86
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

using System;
using System.Net;
using System.Net.Sockets;
using System.Threading.Tasks;

/// <summary>
/// Simple SNTP client: sends one request over UDP and derives the clock offset and round
/// trip time from the server's reply.
/// </summary>
/// <remarks>
/// All failures (timeout, invalid reply, socket errors) are reported by returning false from
/// the RequestTime methods; details are written to the console when <see cref="Debug"/> is set.
/// </remarks>
public class SntpClient {
    /// <summary>Host queried by <see cref="RequestTime(int?)"/>.</summary>
    public string DefaultHostName { get; set; } = "time1.google.com";
    /// <summary>Port used when none is supplied.</summary>
    public int DefaultPort { get; set; } = 123;
    /// <summary>Receive timeout in milliseconds used when none is supplied.</summary>
    public int DefaultTimeout { get; set; } = 1000;
    /// <summary>When true, diagnostics are written to the console (enabled by default).</summary>
    public bool Debug { get; set; } = true;
    private const int REFERENCE_TIME_OFFSET = 16;
    private const int ORIGINATE_TIME_OFFSET = 24;
    private const int RECEIVE_TIME_OFFSET = 32;
    private const int TRANSMIT_TIME_OFFSET = 40;
    private const int NTP_PACKET_SIZE = 48;
    private const int NTP_PORT = 123;
    private const int NTP_MODE_CLIENT = 3;
    private const int NTP_MODE_SERVER = 4;
    private const int NTP_MODE_BROADCAST = 5;
    private const int NTP_VERSION = 3;
    private const int NTP_LEAP_NOSYNC = 3;
    private const int NTP_STRATUM_DEATH = 0;
    private const int NTP_STRATUM_MAX = 15;
    // Number of seconds between Jan 1, 1900 and Jan 1, 1970
    // 70 years plus 17 leap days
    private const long OFFSET_1900_TO_1970 = ((365L * 70L) + 17L) * 24L * 60L * 60L;
    /// <summary>System time in Unix milliseconds computed from the last successful reply.</summary>
    public long NTPTime { get; private set; }
    /// <summary>Value of Environment.TickCount captured when the reply that produced <see cref="NTPTime"/> arrived.</summary>
    public long NTPTimeReference { get; private set; }
    /// <summary>Clock offset in milliseconds computed from the last successful reply.</summary>
    public long NTPOffset { get; private set; }
    /// <summary>Round trip time in milliseconds of the last successful request.</summary>
    public long RoundTripTime { get; private set; }

    private class InvalidServerReplyException : Exception {
        public InvalidServerReplyException(string message) : base(message) {
        }
    }

    /// <summary>Queries the server at <paramref name="endPoint"/>.</summary>
    /// <param name="endPoint">Server address.</param>
    /// <param name="timeout">Receive timeout in milliseconds; <see cref="DefaultTimeout"/> when null.</param>
    /// <returns>True when a valid reply was received and the result properties were updated.</returns>
    public async Task<bool> RequestTime(IPEndPoint endPoint, int? timeout = null) => await DoRequestTime(endPoint: endPoint, timeout: timeout);

    /// <summary>Queries the server at <paramref name="hostName"/>.</summary>
    /// <param name="hostName">Server host name.</param>
    /// <param name="port">Server port; <see cref="DefaultPort"/> when null.</param>
    /// <param name="timeout">Receive timeout in milliseconds; <see cref="DefaultTimeout"/> when null.</param>
    /// <returns>True when a valid reply was received and the result properties were updated.</returns>
    public async Task<bool> RequestTime(string hostName, int? port = 123, int? timeout = null) => await DoRequestTime(hostName: hostName, port: port, timeout: timeout);

    /// <summary>Queries <see cref="DefaultHostName"/> on <see cref="DefaultPort"/>.</summary>
    /// <param name="timeout">Receive timeout in milliseconds; <see cref="DefaultTimeout"/> when null.</param>
    /// <returns>True when a valid reply was received and the result properties were updated.</returns>
    // The port is passed explicitly: with only two positional arguments the timeout would bind to 'port'.
    public async Task<bool> RequestTime(int? timeout = null) => await RequestTime(DefaultHostName, DefaultPort, timeout);

    /// <summary>Sends the request, waits for the reply and stores the computed results.</summary>
    private async Task<bool> DoRequestTime(IPEndPoint endPoint = null, string hostName = null, int? port = null, int? timeout = null) {
        try {
            if (endPoint == null && string.IsNullOrEmpty(hostName)) throw new ArgumentException("No destination specified");
            using (UdpClient client = new UdpClient()) {
                byte[] buffer = new byte[NTP_PACKET_SIZE];
                // set mode = 3 (client) and version = 3
                // mode is in low 3 bits of first byte
                // version is in bits 3-5 of first byte
                buffer[0] = NTP_MODE_CLIENT | (NTP_VERSION << 3);
                // get current time and write it to the request packet
                long requestTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
                int requestTicks = Environment.TickCount;
                writeTimeStamp(buffer, TRANSMIT_TIME_OFFSET, requestTime);
                if (endPoint != null) await client.SendAsync(buffer, buffer.Length, endPoint);
                else await client.SendAsync(buffer, buffer.Length, hostName, port ?? DefaultPort);

                // ReceiveAsync has no timeout of its own, so race it against a delay without
                // blocking the calling thread.
                Task<UdpReceiveResult> receiver = client.ReceiveAsync();
                Task finished = await Task.WhenAny(receiver, Task.Delay(timeout ?? DefaultTimeout));
                if (finished != receiver) {
                    // Disposing the client faults the pending receive; observe that exception
                    // so it does not surface as an unobserved task exception.
                    Task ignored = receiver.ContinueWith(t => { Exception e = t.Exception; }, TaskContinuationOptions.OnlyOnFaulted);
                    if (Debug) Console.WriteLine("Timed out on receive");
                    return false;
                }
                buffer = (await receiver).Buffer;
                int responseTicks = Environment.TickCount;
                // Subtract as 32 bit values so a TickCount wrap between the two samples cancels out.
                long elapsed = unchecked(responseTicks - requestTicks);
                long responseTime = requestTime + elapsed;

                if (buffer.Length < NTP_PACKET_SIZE) {
                    throw new InvalidServerReplyException("short reply: " + buffer.Length + " bytes");
                }
                // extract the results
                byte leap = (byte)((buffer[0] >> 6) & 0x3);
                byte mode = (byte)(buffer[0] & 0x7);
                int stratum = (int)(buffer[1] & 0xff);
                long originateTime = readTimeStamp(buffer, ORIGINATE_TIME_OFFSET);
                long receiveTime = readTimeStamp(buffer, RECEIVE_TIME_OFFSET);
                long transmitTime = readTimeStamp(buffer, TRANSMIT_TIME_OFFSET);
                /* do sanity check according to RFC */
                // TODO: validate originateTime == requestTime.
                checkValidServerReply(leap, mode, stratum, transmitTime);
                long roundTripTime = elapsed - (transmitTime - receiveTime);
                // receiveTime = originateTime + transit + skew
                // responseTime = transmitTime + transit - skew
                // clockOffset = ((receiveTime - originateTime) + (transmitTime - responseTime))/2
                //             = ((originateTime + transit + skew - originateTime) +
                //                (transmitTime - (transmitTime + transit - skew)))/2
                //             = ((transit + skew) + (transmitTime - transmitTime - transit + skew))/2
                //             = (transit + skew - transit + skew)/2
                //             = (2 * skew)/2 = skew
                long clockOffset = ((receiveTime - originateTime) + (transmitTime - responseTime)) / 2;
                if (Debug) Console.WriteLine("round trip: {0}ms clock offset: {1}ms", roundTripTime, clockOffset);
                // save our results - use the times on this side of the network latency
                // (response rather than request time)
                NTPTime = responseTime + clockOffset;
                NTPTimeReference = responseTicks;
                NTPOffset = clockOffset;
                RoundTripTime = roundTripTime;
            }
        } catch (Exception e) {
            // Deliberate best effort: every failure is reported as 'false'.
            if (Debug) Console.WriteLine("request time failed: {0}", e);
            return false;
        }
        return true;
    }

    /// <summary>
    /// Provides the current <see cref="NTPTime"/> as a DateTimeOffset relative to local time or UTC
    /// </summary>
    /// <param name="local">Uses local TimeZoneInfo if true else defaults to UTC</param>
    public DateTimeOffset GetDateTimeOffset(bool local = false) =>
        TimeZoneInfo.ConvertTime(DateTimeOffset.FromUnixTimeMilliseconds(NTPTime), local ? TimeZoneInfo.Local : TimeZoneInfo.Utc);

    /// <summary>
    /// Provides the current NTPOffset as a TimeSpan 
    /// </summary>
    public TimeSpan GetOffset() => TimeSpan.FromMilliseconds(NTPOffset);

    /// <summary>Throws <see cref="InvalidServerReplyException"/> when the reply header fields are not acceptable.</summary>
    private static void checkValidServerReply(byte leap, byte mode, int stratum, long transmitTime) {
        if (leap == NTP_LEAP_NOSYNC) {
            throw new InvalidServerReplyException("unsynchronized server");
        }
        if ((mode != NTP_MODE_SERVER) && (mode != NTP_MODE_BROADCAST)) {
            throw new InvalidServerReplyException("untrusted mode: " + mode);
        }
        if ((stratum == NTP_STRATUM_DEATH) || (stratum > NTP_STRATUM_MAX)) {
            throw new InvalidServerReplyException("untrusted stratum: " + stratum);
        }
        if (transmitTime == 0) {
            throw new InvalidServerReplyException("zero transmitTime");
        }
    }

    /**
     * Reads an unsigned 32 bit big endian number from the given offset in the buffer.
     */
    private static long read32(byte[] buffer, int offset) {
        // C# bytes are already unsigned, so they can be widened and combined directly.
        return ((long)buffer[offset] << 24) | ((long)buffer[offset + 1] << 16) |
               ((long)buffer[offset + 2] << 8) | (long)buffer[offset + 3];
    }

    /**
     * Reads the NTP time stamp at the given offset in the buffer and returns
     * it as a system time (milliseconds since January 1, 1970).
     */
    private static long readTimeStamp(byte[] buffer, int offset) {
        long seconds = read32(buffer, offset);
        long fraction = read32(buffer, offset + 4);
        // Special case: zero means zero.
        if (seconds == 0 && fraction == 0) {
            return 0;
        }
        return ((seconds - OFFSET_1900_TO_1970) * 1000) + ((fraction * 1000L) / 0x100000000L);
    }

    /**
     * Writes system time (milliseconds since January 1, 1970) as an NTP time stamp
     * at the given offset in the buffer.
     */
    private static void writeTimeStamp(byte[] buffer, int offset, long time) {
        // Special case: zero means zero.
        if (time == 0) {
            Array.Clear(buffer, offset, 8);
            return;
        }
        long seconds = time / 1000L;
        long milliseconds = time - seconds * 1000L;
        seconds += OFFSET_1900_TO_1970;
        // write seconds in big endian format
        buffer[offset++] = (byte)(seconds >> 24);
        buffer[offset++] = (byte)(seconds >> 16);
        buffer[offset++] = (byte)(seconds >> 8);
        buffer[offset++] = (byte)(seconds >> 0);
        long fraction = milliseconds * 0x100000000L / 1000L;
        // write fraction in big endian format
        buffer[offset++] = (byte)(fraction >> 24);
        buffer[offset++] = (byte)(fraction >> 16);
        buffer[offset++] = (byte)(fraction >> 8);
        // low order bits should be random data (upper bound is exclusive, so 256 covers a full byte)
        buffer[offset++] = (byte)(new Random().Next(256));
    }

}

Wednesday, 14 September 2016

Encrypting Data for Web Browser Push API Notifications

This isn't perfect nor optimised for production use but should work with Firefox and Chrome.

Unfortunately requires Bouncy Castle due to the lack of native encryption support in .NET Core.

FirebaseServerKey is needed if you are submitting messages to Chrome browsers.

Gist available at https://gist.github.com/BravoTango86/2265a0eb1abcd669a9c8a2d60bd653c3


/* 
 * Built for .NET Core 1.0 on Windows 10 with Portable.BouncyCastle v1.8.1.1
 * 
 * Tested on Chrome v53.0.2785.113 m (64-bit) and Firefox 48.0.2
 * 
 * Massive thanks to Peter Beverloo for the following:
 * https://docs.google.com/document/d/1_kWRLJHRYN0KH73WipFyfIXI1UzZ5IyOYSs-y_mLxEE/
 * https://tests.peter.sh/push-encryption-verifier/
 * 
 * Some more useful links:
 * https://developers.google.com/web/updates/2016/03/web-push-encryption?hl=en
 * https://github.com/web-push-libs/web-push/blob/master/src/index.js
 * 
 * Copyright (C) 2016 BravoTango86
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

using Microsoft.AspNetCore.WebUtilities;
using Org.BouncyCastle.Asn1.X9;
using Org.BouncyCastle.Crypto;
using Org.BouncyCastle.Crypto.Agreement;
using Org.BouncyCastle.Crypto.Generators;
using Org.BouncyCastle.Crypto.Parameters;
using Org.BouncyCastle.Math;
using Org.BouncyCastle.Security;
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Text;

/// <summary>
/// Helpers for encrypting a payload with the "aesgcm" content encoding and
/// POSTing it to a browser push-service endpoint.
/// </summary>
public class WebPushHelper {

    // Sent as "Authorization: key=..." to GCM endpoints only. Left empty here;
    // supply the real key from configuration rather than committing it to source.
    private const string FirebaseServerKey = "";

    // Endpoints starting with this prefix get the Authorization header above.
    private const string GcmEndpointPrefix = "https://android.googleapis.com/gcm/send/";

    // One shared client for every request: creating an HttpClient per call
    // leaves sockets lingering and can exhaust them under load.
    private static readonly HttpClient Client = new HttpClient();

    /// <summary>
    /// Sends a notification to the subscription described by <paramref name="sub"/>.
    /// </summary>
    /// <param name="sub">Subscription whose <c>keys</c> dictionary holds the Base64Url encoded
    /// <c>p256dh</c> and <c>auth</c> entries.</param>
    /// <param name="data">Raw payload bytes, or <c>null</c> to send a request without a body.</param>
    /// <param name="ttl">Value of the <c>TTL</c> request header.</param>
    /// <param name="padding">Number of padding bytes (or the upper bound when randomised).</param>
    /// <param name="randomisePadding">Pick a random padding length between 0 and <paramref name="padding"/>.</param>
    /// <returns><c>true</c> when the endpoint answered with HTTP 201 (Created).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="sub"/> is <c>null</c>.</exception>
    public static bool SendNotification(JsonSubscription sub, byte[] data, int ttl = 0, ushort padding = 0,
                                        bool randomisePadding = false) {
        if (sub == null) throw new ArgumentNullException(nameof(sub));
        return SendNotification(endpoint: sub.endpoint,
                                data: data,
                                userKey: WebEncoders.Base64UrlDecode(sub.keys["p256dh"]),
                                userSecret: WebEncoders.Base64UrlDecode(sub.keys["auth"]),
                                ttl: ttl,
                                padding: padding,
                                randomisePadding: randomisePadding);
    }

    /// <summary>
    /// Sends a notification using a text payload and Base64Url encoded keys.
    /// </summary>
    /// <param name="endpoint">Push-service URL of the subscription.</param>
    /// <param name="data">Payload text, encoded as UTF-8; <c>null</c> sends a request without a body.</param>
    /// <param name="userKey">Base64Url encoded <c>p256dh</c> key of the subscription.</param>
    /// <param name="userSecret">Base64Url encoded <c>auth</c> secret of the subscription.</param>
    /// <param name="ttl">Value of the <c>TTL</c> request header.</param>
    /// <param name="padding">Number of padding bytes (or the upper bound when randomised).</param>
    /// <param name="randomisePadding">Pick a random padding length between 0 and <paramref name="padding"/>.</param>
    /// <returns><c>true</c> when the endpoint answered with HTTP 201 (Created).</returns>
    public static bool SendNotification(string endpoint, string data, string userKey, string userSecret,
                                        int ttl = 0, ushort padding = 0, bool randomisePadding = false) {
        // The byte[] overload treats null data as "no payload"; pass null through
        // instead of letting Encoding.GetBytes throw on it.
        return SendNotification(endpoint: endpoint,
                                data: data == null ? null : Encoding.UTF8.GetBytes(data),
                                userKey: WebEncoders.Base64UrlDecode(userKey),
                                userSecret: WebEncoders.Base64UrlDecode(userSecret),
                                ttl: ttl,
                                padding: padding,
                                randomisePadding: randomisePadding);
    }

    /// <summary>
    /// POSTs a notification to <paramref name="endpoint"/>. The payload is encrypted and attached
    /// only when <paramref name="data"/>, <paramref name="userKey"/> and <paramref name="userSecret"/>
    /// are all non-null; otherwise the request is sent without a body.
    /// </summary>
    /// <param name="endpoint">Push-service URL of the subscription.</param>
    /// <param name="userKey">Decoded <c>p256dh</c> key of the subscription.</param>
    /// <param name="userSecret">Decoded <c>auth</c> secret of the subscription.</param>
    /// <param name="data">Raw payload bytes, or <c>null</c> for no payload.</param>
    /// <param name="ttl">Value of the <c>TTL</c> request header.</param>
    /// <param name="padding">Number of padding bytes (or the upper bound when randomised).</param>
    /// <param name="randomisePadding">Pick a random padding length between 0 and <paramref name="padding"/>.</param>
    /// <returns><c>true</c> when the endpoint answered with HTTP 201 (Created).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="endpoint"/> is <c>null</c>.</exception>
    public static bool SendNotification(string endpoint, byte[] userKey, byte[] userSecret, byte[] data = null,
                                    int ttl = 0, ushort padding = 0, bool randomisePadding = false) {
        if (endpoint == null) throw new ArgumentNullException(nameof(endpoint));
        using (HttpRequestMessage Request = new HttpRequestMessage(HttpMethod.Post, endpoint)) {
            // URL prefix match, so compare ordinally rather than with the current culture.
            if (endpoint.StartsWith(GcmEndpointPrefix, StringComparison.Ordinal))
                Request.Headers.TryAddWithoutValidation("Authorization", "key=" + FirebaseServerKey);
            // Header values are machine-readable: format the number culture-independently.
            Request.Headers.Add("TTL", ttl.ToString(System.Globalization.CultureInfo.InvariantCulture));
            if (data != null && userKey != null && userSecret != null) {
                EncryptionResult Package = EncryptMessage(userKey, userSecret, data, padding, randomisePadding);
                Request.Content = new ByteArrayContent(Package.Payload);
                Request.Content.Headers.ContentType = new MediaTypeHeaderValue("application/octet-stream");
                Request.Content.Headers.ContentLength = Package.Payload.Length;
                Request.Content.Headers.ContentEncoding.Add("aesgcm");
                Request.Headers.Add("Crypto-Key", "keyid=p256dh;dh=" + WebEncoders.Base64UrlEncode(Package.PublicKey));
                Request.Headers.Add("Encryption", "keyid=p256dh;salt=" + WebEncoders.Base64UrlEncode(Package.Salt));
            }
            // The method is synchronous by contract, so the send is still awaited by blocking on
            // .Result; failures therefore keep surfacing wrapped in an AggregateException.
            using (HttpResponseMessage Response = Client.SendAsync(Request).Result) {
                return Response.StatusCode == HttpStatusCode.Created;
            }
        }
    }

    /// <summary>
    /// Encrypts <paramref name="data"/> for the subscription identified by
    /// <paramref name="userKey"/> and <paramref name="userSecret"/>: generates a random 16-byte salt
    /// and an ephemeral P-256 key pair, derives the content-encryption key and nonce via HKDF,
    /// and encrypts the padded payload with AES-128-GCM.
    /// </summary>
    /// <param name="userKey">Recipient public key as an encoded EC point on prime256v1.</param>
    /// <param name="userSecret">Recipient authentication secret (used as the first HKDF salt).</param>
    /// <param name="data">Plaintext payload.</param>
    /// <param name="padding">Number of zero padding bytes inserted before the payload, or the
    /// upper bound when <paramref name="randomisePadding"/> is set.</param>
    /// <param name="randomisePadding">Pick a random padding length between 0 and <paramref name="padding"/>.</param>
    /// <returns>The salt, the sender's uncompressed public key and the ciphertext (tag included).</returns>
    /// <exception cref="ArgumentNullException">Any of the byte array arguments is <c>null</c>.</exception>
    public static EncryptionResult EncryptMessage(byte[] userKey, byte[] userSecret, byte[] data,
                                                  ushort padding = 0, bool randomisePadding = false) {
        if (userKey == null) throw new ArgumentNullException(nameof(userKey));
        if (userSecret == null) throw new ArgumentNullException(nameof(userSecret));
        if (data == null) throw new ArgumentNullException(nameof(data));

        SecureRandom Rng = new SecureRandom();
        byte[] Salt = new byte[16];
        Rng.NextBytes(Salt);

        // Ephemeral sender key pair on the P-256 curve.
        X9ECParameters Curve = ECNamedCurveTable.GetByName("prime256v1");
        ECDomainParameters Spec = new ECDomainParameters(Curve.Curve, Curve.G, Curve.N, Curve.H, Curve.GetSeed());
        ECKeyPairGenerator Generator = new ECKeyPairGenerator();
        Generator.Init(new ECKeyGenerationParameters(Spec, Rng));
        AsymmetricCipherKeyPair KeyPair = Generator.GenerateKeyPair();

        // ECDH shared secret between the ephemeral private key and the recipient's public key.
        ECDHBasicAgreement AgreementGenerator = new ECDHBasicAgreement();
        AgreementGenerator.Init(KeyPair.Private);
        BigInteger IKM = AgreementGenerator.CalculateAgreement(new ECPublicKeyParameters(Spec.Curve.DecodePoint(userKey), Spec));

        // ToByteArrayUnsigned() drops leading zero bytes, but the receiver keys HKDF with the
        // fixed-width secret. Left-pad to the field size so both sides derive the same keys.
        byte[] SharedSecret = IKM.ToByteArrayUnsigned();
        int FieldSize = AgreementGenerator.GetFieldSize();
        if (SharedSecret.Length < FieldSize) {
            byte[] Padded = new byte[FieldSize];
            Buffer.BlockCopy(SharedSecret, 0, Padded, FieldSize - SharedSecret.Length, SharedSecret.Length);
            SharedSecret = Padded;
        }

        byte[] PRK = GenerateHKDF(userSecret, SharedSecret, Encoding.UTF8.GetBytes("Content-Encoding: auth\0"), 32);
        byte[] PublicKey = ((ECPublicKeyParameters)KeyPair.Public).Q.GetEncoded(false);
        byte[] CEK = GenerateHKDF(Salt, PRK, CreateInfoChunk("aesgcm", userKey, PublicKey), 16);
        byte[] Nonce = GenerateHKDF(Salt, PRK, CreateInfoChunk("nonce", userKey, PublicKey), 12);

        // Next(max) returns a value in [0, max), i.e. 0..padding inclusive. It replaces
        // Math.Abs(NextInt()) % n, which throws OverflowException when NextInt() is int.MinValue.
        if (randomisePadding && padding > 0) padding = (ushort)Rng.Next(padding + 1);

        // Plaintext layout: 2-byte big-endian padding length, `padding` zero bytes, then the data.
        byte[] Input = new byte[padding + 2 + data.Length];
        Buffer.BlockCopy(ConvertInt(padding), 0, Input, 0, 2);
        Buffer.BlockCopy(data, 0, Input, padding + 2, data.Length);

        IBufferedCipher Cipher = CipherUtilities.GetCipher("AES/GCM/NoPadding");
        Cipher.Init(true, new AeadParameters(new KeyParameter(CEK), 128, Nonce));
        byte[] Message = new byte[Cipher.GetOutputSize(Input.Length)];
        Cipher.DoFinal(Input, 0, Input.Length, Message, 0);
        return new EncryptionResult() { Salt = Salt, Payload = Message, PublicKey = PublicKey };
    }

    /// <summary>Output of <see cref="EncryptMessage"/>.</summary>
    public class EncryptionResult {
        /// <summary>Sender's ephemeral public key as an uncompressed EC point.</summary>
        public byte[] PublicKey { get; set; }
        /// <summary>AES-GCM ciphertext produced from the padded payload.</summary>
        public byte[] Payload { get; set; }
        /// <summary>Random 16-byte salt used for key and nonce derivation.</summary>
        public byte[] Salt { get; set; }
    }

    /// <summary>
    /// Subscription holder: an endpoint URL plus a key dictionary from which
    /// <see cref="SendNotification(JsonSubscription, byte[], int, ushort, bool)"/> reads the
    /// <c>p256dh</c> and <c>auth</c> entries.
    /// </summary>
    public class JsonSubscription {
        public string endpoint { get; set; }
        public Dictionary<string, string> keys { get; set; }
    }

    /// <summary>
    /// Encodes <paramref name="number"/> as a 2-byte big-endian unsigned integer.
    /// </summary>
    /// <param name="number">Value in the range 0..65535.</param>
    /// <returns>Two bytes, most significant first.</returns>
    /// <exception cref="OverflowException"><paramref name="number"/> does not fit in 16 bits.</exception>
    public static byte[] ConvertInt(int number) {
        ushort Value = Convert.ToUInt16(number);
        return new[] { (byte)(Value >> 8), (byte)Value };
    }

    /// <summary>
    /// Builds the HKDF "info" parameter:
    /// <c>"Content-Encoding: {type}\0P-256\0"</c> followed by the length-prefixed recipient
    /// and sender public keys (each prefixed with its 2-byte big-endian length).
    /// </summary>
    /// <param name="type">Content-encoding label, e.g. <c>aesgcm</c> or <c>nonce</c>.</param>
    /// <param name="recipientPublicKey">Recipient (browser) public key bytes.</param>
    /// <param name="senderPublicKey">Sender (ephemeral) public key bytes.</param>
    /// <returns>The concatenated info bytes.</returns>
    public static byte[] CreateInfoChunk(string type, byte[] recipientPublicKey, byte[] senderPublicKey) {
        List<byte> Output = new List<byte>();
        Output.AddRange(Encoding.UTF8.GetBytes($"Content-Encoding: {type}\0P-256\0"));
        Output.AddRange(ConvertInt(recipientPublicKey.Length));
        Output.AddRange(recipientPublicKey);
        Output.AddRange(ConvertInt(senderPublicKey.Length));
        Output.AddRange(senderPublicKey);
        return Output.ToArray();
    }

    /// <summary>
    /// HKDF (extract-then-expand) with HMAC-SHA256.
    /// </summary>
    /// <param name="salt">Salt used as the HMAC key in the extract step.</param>
    /// <param name="ikm">Input keying material.</param>
    /// <param name="info">Context information mixed into every expand block.</param>
    /// <param name="len">Number of output bytes, 0..8160 (255 blocks of 32 bytes).</param>
    /// <returns>Exactly <paramref name="len"/> bytes of derived key material.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="len"/> is negative or
    /// larger than 255 hash blocks.</exception>
    public static byte[] GenerateHKDF(byte[] salt, byte[] ikm, byte[] info, int len) {
        const int HashLength = 32;
        if (len < 0 || len > 255 * HashLength) throw new ArgumentOutOfRangeException(nameof(len));

        // Extract: PRK = HMAC(salt, ikm).
        byte[] PRK = MacUtilities.CalculateMac("HmacSHA256", new KeyParameter(salt), ikm);

        // Expand: T(i) = HMAC(PRK, T(i-1) | info | i), concatenated until len bytes are produced.
        // For len <= 32 this is a single block, i.e. the first len bytes of HMAC(PRK, info | 0x01).
        IMac Expander = MacUtilities.GetMac("HmacSHA256");
        Expander.Init(new KeyParameter(PRK));
        byte[] Output = new byte[len];
        byte[] Block = null;
        int Written = 0;
        for (int Counter = 1; Written < len; Counter++) {
            if (Block != null) Expander.BlockUpdate(Block, 0, Block.Length);
            Expander.BlockUpdate(info, 0, info.Length);
            Expander.Update((byte)Counter);
            Block = MacUtilities.DoFinal(Expander);
            int Take = Math.Min(Block.Length, len - Written);
            Buffer.BlockCopy(Block, 0, Output, Written, Take);
            Written += Take;
        }
        return Output;
    }

}