My previous post talked about the ChaCha cryptographically secure random number generator (CSPRNG) and how Google is using it in a stream cipher for encryption on low-end devices. This post talks about how to implement ChaCha in pure Python.
First of all, the only reason to implement ChaCha in pure Python is to play with it. It would be more natural and more efficient to implement ChaCha in C.
RFC 8439 gives detailed, language-neutral directions for how to implement ChaCha, including test cases for intermediate results. At its core is the function that does a “quarter round” operation on four unsigned 32-bit integers. This function depends on three operations:
- addition mod 2^32, denoted +,
- bitwise XOR, denoted ^, and
- bit rotation, denoted <<<= n.
In C, the += operator on unsigned 32-bit integers would do what the RFC denotes by +=, because unsigned arithmetic in C wraps around mod 2^32. Python integers have arbitrary precision and never overflow, so we need to take remainders mod 2^32 explicitly. The Python bitwise XOR operator ^ can be used directly. We’ll write a function roll that corresponds to <<<=.
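To make the three operations concrete, here is a small sketch that replays the worked example from Section 2.1 of RFC 8439 (c += d; b ^= c; b <<<= 7;). The helper names add32 and rotl32 and the constant MASK32 are my own, not the RFC's. They use masking with & 0xFFFFFFFF, which for nonnegative integers keeps the same low 32 bits as taking the remainder mod 2**32.

```python
MASK32 = 0xFFFFFFFF  # 2**32 - 1, selects the low 32 bits

def add32(x, y):
    """Addition mod 2**32; masking with & is equivalent to % 2**32 here."""
    return (x + y) & MASK32

def rotl32(x, n):
    """Rotate the 32-bit word x left by n bits, for 0 <= n < 32."""
    return ((x << n) & MASK32) | (x >> (32 - n))

# Worked example from RFC 8439, Section 2.1:  c += d; b ^= c; b <<<= 7;
b, c, d = 0x01020304, 0x77777777, 0x01234567
c = add32(c, d)       # 0x789abcde
xored = b ^ c         # 0x7998bfda
b = rotl32(xored, 7)  # 0xcc5fed3c
print(hex(c), hex(xored), hex(b))
```

Checking small steps like this against the RFC's intermediate values makes it much easier to find a mistake than debugging a whole block at once.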
So the following line of pseudocode from the RFC
a += b; d ^= a; d <<<= 16;
becomes
a = (a+b) % 2**32; d = roll(d^a, 16)
in Python. One way to implement roll would be to use the bitstring library:
from bitstring import Bits

def roll(x, n):
    bits = Bits(uint=x, length=32)
    return (bits[n:] + bits[:n]).uint
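Another approach, a little harder to understand but not needing an external library, would be:

def roll2(x, n):
    # 2 << 31 == 2**32: keep the low 32 bits of x << n,
    # then bring the n bits that fell off the top around to the bottom
    return (x << n) % (2 << 31) + (x >> (32-n))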
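One way to convince yourself that roll2 really is a rotation is to compare it with the bitstring idea on random inputs. The sketch below avoids the external library by using Python's format to get a 32-character binary string; the name roll_str is hypothetical, chosen for this illustration.

```python
import random

def roll2(x, n):
    # Same expression as above; 2 << 31 == 2**32
    return (x << n) % (2 << 31) + (x >> (32 - n))

def roll_str(x, n):
    """Library-free analogue of the bitstring version: rotate the
    32-character binary string left by n and parse it back."""
    bits = format(x, "032b")
    return int(bits[n:] + bits[:n], 2)

random.seed(0)
for _ in range(1000):
    x = random.getrandbits(32)
    n = random.randrange(1, 32)
    assert roll2(x, n) == roll_str(x, n)
print("roll2 agrees with the string rotation")
```

The + in roll2 works as well as | would: the low n bits of the shifted term are zero, and x >> (32-n) is less than 2**n, so the two terms never overlap.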
So here’s an implementation of the ChaCha quarter round:
def quarter_round(a, b, c, d):
    a = (a+b) % 2**32; d = roll(d^a, 16)
    c = (c+d) % 2**32; b = roll(b^c, 12)
    a = (a+b) % 2**32; d = roll(d^a,  8)
    c = (c+d) % 2**32; b = roll(b^c,  7)
    return a, b, c, d
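Section 2.1.1 of RFC 8439 gives a test vector for the quarter round. The sketch below repeats the quarter round so it runs on its own, using a library-free roll equivalent to roll2, and applies it to that vector.

```python
def roll(x, n):
    """Rotate the 32-bit unsigned integer x left by n bits."""
    return (x << n) % 2**32 + (x >> (32 - n))

def quarter_round(a, b, c, d):
    a = (a + b) % 2**32; d = roll(d ^ a, 16)
    c = (c + d) % 2**32; b = roll(b ^ c, 12)
    a = (a + b) % 2**32; d = roll(d ^ a, 8)
    c = (c + d) % 2**32; b = roll(b ^ c, 7)
    return a, b, c, d

# Test vector from RFC 8439, Section 2.1.1
a, b, c, d = quarter_round(0x11111111, 0x01020304, 0x9b8d6f43, 0x01234567)
print(hex(a), hex(b), hex(c), hex(d))
# 0xea2a92f4 0xcb1cf8ce 0x4581472e 0x5881c4bb
```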
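ChaCha has a state consisting of 16 unsigned 32-bit integers. A “round” of ChaCha consists of four quarter rounds, each operating on four of these integers at a time; ChaCha20 alternates between “column rounds” and “diagonal rounds,” for 20 rounds in all. All the details are in the RFC.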
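For a sense of how the rounds fit together, here is a minimal sketch of the ChaCha20 block function as I read Section 2.3 of RFC 8439: the state is four constants, eight key words, a 32-bit block counter, and three nonce words (the RFC's 96-bit nonce layout), all read little-endian. It runs 10 double rounds (a column round followed by a diagonal round), adds the original state back in, and serializes the result little-endian. The names qr_on_state and chacha20_block are mine, for illustration only.

```python
import struct

def roll(x, n):
    """Rotate the 32-bit unsigned integer x left by n bits."""
    return ((x << n) & 0xFFFFFFFF) | (x >> (32 - n))

def quarter_round(a, b, c, d):
    a = (a + b) % 2**32; d = roll(d ^ a, 16)
    c = (c + d) % 2**32; b = roll(b ^ c, 12)
    a = (a + b) % 2**32; d = roll(d ^ a, 8)
    c = (c + d) % 2**32; b = roll(b ^ c, 7)
    return a, b, c, d

def qr_on_state(s, i, j, k, l):
    """Apply the quarter round in place to state words i, j, k, l."""
    s[i], s[j], s[k], s[l] = quarter_round(s[i], s[j], s[k], s[l])

def chacha20_block(key, counter, nonce):
    """key: 32 bytes, counter: int below 2**32, nonce: 12 bytes.
    Returns one 64-byte block of keystream."""
    assert len(key) == 32 and len(nonce) == 12
    constants = [0x61707865, 0x3320646e, 0x79622d32, 0x6b206574]  # "expand 32-byte k"
    state = (constants
             + list(struct.unpack("<8I", key))
             + [counter]
             + list(struct.unpack("<3I", nonce)))
    working = state[:]
    for _ in range(10):  # 10 double rounds = 20 rounds
        # column round
        qr_on_state(working, 0, 4, 8, 12)
        qr_on_state(working, 1, 5, 9, 13)
        qr_on_state(working, 2, 6, 10, 14)
        qr_on_state(working, 3, 7, 11, 15)
        # diagonal round
        qr_on_state(working, 0, 5, 10, 15)
        qr_on_state(working, 1, 6, 11, 12)
        qr_on_state(working, 2, 7, 8, 13)
        qr_on_state(working, 3, 4, 9, 14)
    out = [(w + s) & 0xFFFFFFFF for w, s in zip(working, state)]
    return struct.pack("<16I", *out)

block = chacha20_block(bytes(range(32)), 1, bytes.fromhex("000000090000004a00000000"))
print(block[:16].hex())
```

With the key 00:01:...:1f, nonce 00:00:00:09:00:00:00:4a:00:00:00:00 and block count 1 from Section 2.3.2 of the RFC, the serialized block should begin 10 f1 e7 e4 d1 3b 59 15. Encryption would XOR successive blocks (counter, counter + 1, ...) with the plaintext; that part is left out here.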
Incidentally, the inner workings of the BLAKE2 secure hash function are similar to those of ChaCha.
For this purpose I prefer to work with fixed-size integers from NumPy:
>>> import warnings
>>> import numpy as np
>>> warnings.filterwarnings('ignore')  # silence NumPy's overflow RuntimeWarning
>>> print(np.uint32(4294967295) + np.uint32(7))
6
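Taking the NumPy suggestion one step further, here is a sketch (my own, not from the post or the comment) that stores the state as a 4×4 uint32 array. Arithmetic on uint32 arrays wraps mod 2**32 without the RuntimeWarning that NumPy scalars emit, so no warning filter is needed. Because each column round applies the quarter round to the four columns independently, the quarter round can act on whole rows at once; for the diagonal round, rows 1 to 3 are rotated left by 1 to 3 positions so the diagonals line up as columns, then rotated back. The names rotl, quarter_round_vec, double_round and chacha20_block_np are hypothetical.

```python
import numpy as np

def rotl(x, n):
    # x is a uint32 array; shifting by a Python int keeps the uint32 dtype,
    # and bits shifted past bit 31 are discarded
    return (x << n) | (x >> (32 - n))

def quarter_round_vec(a, b, c, d):
    # Each argument is a length-4 uint32 array: four quarter rounds at once
    a = a + b; d = rotl(d ^ a, 16)
    c = c + d; b = rotl(b ^ c, 12)
    a = a + b; d = rotl(d ^ a, 8)
    c = c + d; b = rotl(b ^ c, 7)
    return a, b, c, d

def double_round(x):
    """x: 4x4 uint32 array holding state words 0-3, 4-7, 8-11, 12-15 as rows."""
    # column round: QR(0,4,8,12), QR(1,5,9,13), ... all at once
    a, b, c, d = quarter_round_vec(x[0], x[1], x[2], x[3])
    # diagonal round: shift rows so QR(0,5,10,15), QR(1,6,11,12), ... become columns
    a, b, c, d = quarter_round_vec(a, np.roll(b, -1), np.roll(c, -2), np.roll(d, -3))
    # undo the shifts
    return np.stack([a, np.roll(b, 1), np.roll(c, 2), np.roll(d, 3)])

def chacha20_block_np(key, counter, nonce):
    """key: 32 bytes, counter: int below 2**32, nonce: 12 bytes; RFC 8439 layout."""
    constants = np.array([0x61707865, 0x3320646e, 0x79622d32, 0x6b206574], dtype=np.uint32)
    state = np.concatenate([
        constants,
        np.frombuffer(key, dtype="<u4"),
        np.array([counter], dtype=np.uint32),
        np.frombuffer(nonce, dtype="<u4"),
    ]).astype(np.uint32).reshape(4, 4)
    x = state.copy()
    for _ in range(10):  # 10 double rounds = 20 rounds
        x = double_round(x)
    return (x + state).astype("<u4").tobytes()  # wraps mod 2**32, little-endian

block = chacha20_block_np(bytes(range(32)), 1, bytes.fromhex("000000090000004a00000000"))
print(block[:16].hex())
```

For a single 64-byte block the array overhead probably outweighs any gain, so this is more a way to see the column/diagonal structure than a speedup.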