6
\$\begingroup\$

For educational purposes, I have implemented the AES block cipher in Python. I decided to follow the interface for block cipher modules as defined in PEP 272. The implementation consists of two Python files, aes.py and block_cipher.py.

aes.py (~300 lines of code)

    # coding: utf-8
    """
    Advanced Encryption Standard.

    The implementations of mix_columns() and inv_mix_columns() use cl_mul with
    hardcoded factors in order to prevent side channel attacks
    """

    from block_cipher import BlockCipher, BlockCipherWrapper
    from block_cipher import MODE_ECB, MODE_CBC, MODE_CFB, MODE_OFB, MODE_CTR

    __all__ = [
        'new', 'block_size', 'key_size',
        'MODE_ECB', 'MODE_CBC', 'MODE_CFB', 'MODE_OFB', 'MODE_CTR'
    ]


    SBOX = (
        0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
        0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
        0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
        0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
        0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
        0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
        0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
        0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
        0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
        0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
        0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
        0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
        0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
        0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
        0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
        0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
    )
    INV_SBOX = (
        0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
        0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
        0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
        0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
        0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
        0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
        0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
        0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
        0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
        0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
        0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
        0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
        0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
        0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
        0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
        0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
    )

    round_constants = (0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36)

    block_size = 16
    key_size = None


    def new(key, mode, IV=None, **kwargs) -> BlockCipherWrapper:
        if mode in (MODE_CBC, MODE_CFB, MODE_OFB) and IV is None:
            raise ValueError("This mode requires an IV")
        cipher = BlockCipherWrapper()
        cipher.block_size = block_size
        cipher.IV = IV
        cipher.mode = mode
        cipher.cipher = AES(key)
        if mode == MODE_CFB:
            cipher.segment_size = kwargs.get('segment_size', block_size * 8)
        elif mode == MODE_CTR:
            counter = kwargs.get('counter')
            if counter is None:
                raise ValueError("CTR mode requires a callable counter object")
            cipher.counter = counter
        return cipher


    class AES(BlockCipher):
        def __init__(self, key: bytes):
            self.key = key
            self.Nk = len(self.key) // 4  # words per key
            if self.Nk not in (4, 6, 8):
                raise ValueError("invalid key size")
            self.Nr = self.Nk + 6
            self.Nb = 4  # words per block
            self.state: list[list[int]] = []
            # raise NotImplementedError
            # key schedule
            self.w: list[list[int]] = []
            for i in range(self.Nk):
                self.w.append(list(key[4*i:4*i+4]))
            for i in range(self.Nk, self.Nb*(self.Nr+1)):
                tmp: list[int] = self.w[i-1]
                q, r = divmod(i, self.Nk)
                if not r:
                    tmp = self.sub_word(self.rot_word(tmp))
                    tmp[0] ^= round_constants[q-1]
                elif self.Nk > 6 and r == 4:
                    tmp = self.sub_word(tmp)
                self.w.append(
                    [a ^ b for a, b in zip(self.w[i-self.Nk], tmp)]
                )

        def encrypt_block(self, block: bytes) -> bytes:
            self.set_state(block)
            self.add_round_key(0)
            for r in range(1, self.Nr):
                self.sub_bytes()
                self.shift_rows()
                self.mix_columns()
                self.add_round_key(r)
            self.sub_bytes()
            self.shift_rows()
            self.add_round_key(self.Nr)
            return self.get_state()

        def decrypt_block(self, block: bytes) -> bytes:
            self.set_state(block)
            self.add_round_key(self.Nr)
            for r in range(self.Nr-1, 0, -1):
                self.inv_shift_rows()
                self.inv_sub_bytes()
                self.add_round_key(r)
                self.inv_mix_columns()
            self.inv_shift_rows()
            self.inv_sub_bytes()
            self.add_round_key(0)
            return self.get_state()

        @staticmethod
        def rot_word(word: list[int]):
            # for key schedule
            return word[1:] + word[:1]

        @staticmethod
        def sub_word(word: list[int]):
            # for key schedule
            return [SBOX[b] for b in word]

        def set_state(self, block: bytes):
            self.state = [
                list(block[i:i+4])
                for i in range(0, 16, 4)
            ]

        def get_state(self) -> bytes:
            return b''.join(
                bytes(col)
                for col in self.state
            )

        def add_round_key(self, r: int):
            round_key = self.w[r*self.Nb:(r+1)*self.Nb]
            for col, word in zip(self.state, round_key):
                for row_index in range(4):
                    col[row_index] ^= word[row_index]

        def mix_columns(self):
            for i, word in enumerate(self.state):
                new_word = []
                for j in range(4):
                    # element wise cl mul with constants 2, 3, 1, 1
                    value = (word[0] << 1)
                    value ^= (word[1] << 1) ^ word[1]
                    value ^= word[2] ^ word[3]
                    # polynomial reduction in constant time
                    value ^= 0x11b & -(value >> 8)
                    new_word.append(value)
                    # rotate word in order to match the matrix multiplication
                    word = self.rot_word(word)
                self.state[i] = new_word

        def inv_mix_columns(self):
            for i, word in enumerate(self.state):
                new_word = []
                for j in range(4):
                    # element wise cl mul with constants 0xe, 0xb, 0xd, 0x9
                    value = (word[0] << 3) ^ (word[0] << 2) ^ (word[0] << 1)
                    value ^= (word[1] << 3) ^ (word[1] << 1) ^ word[1]
                    value ^= (word[2] << 3) ^ (word[2] << 2) ^ word[2]
                    value ^= (word[3] << 3) ^ word[3]
                    # polynomial reduction in constant time
                    value ^= (0x11b << 2) & -(value >> 10)
                    value ^= (0x11b << 1) & -(value >> 9)
                    value ^= 0x11b & -(value >> 8)
                    new_word.append(value)
                    # rotate word in order to match the matrix multiplication
                    word = self.rot_word(word)
                self.state[i] = new_word

        def shift_rows(self):
            for row_index in range(4):
                row = [
                    col[row_index] for col in self.state
                ]
                row = row[row_index:] + row[:row_index]
                for col_index in range(4):
                    self.state[col_index][row_index] = row[col_index]

        def inv_shift_rows(self):
            for row_index in range(4):
                row = [
                    col[row_index] for col in self.state
                ]
                row = row[-row_index:] + row[:-row_index]
                for col_index in range(4):
                    self.state[col_index][row_index] = row[col_index]

        def sub_bytes(self):
            for col in self.state:
                for row_index in range(4):
                    col[row_index] = SBOX[col[row_index]]

        def inv_sub_bytes(self):
            for col in self.state:
                for row_index in range(4):
                    col[row_index] = INV_SBOX[col[row_index]]

        def print_state(self):
            # debug function
            for row_index in range(4):
                print(' '.join(f'{col[row_index]:02x}' for col in self.state))
            print()


    def main():
        key = bytes.fromhex('2b 7e 15 16 28 ae d2 a6 ab f7 15 88 09 cf 4f 3c')
        cipher = new(key, MODE_ECB)
        plain_text = bytes.fromhex('32 43 f6 a8 88 5a 30 8d 31 31 98 a2 e0 37 07 34')
        print(plain_text)
        cipher_text = cipher.encrypt(plain_text)
        print(cipher_text)
        print(cipher.decrypt(cipher_text))


    if __name__ == '__main__':
        main()

block_cipher.py (~200 lines of code)

    # coding: utf-8
    """
    this module provides the following classes:

    BlockCipher
    Counter
    BlockCipherWrapper
    """

    MODE_ECB = 1
    MODE_CBC = 2
    MODE_CFB = 3
    MODE_PGP = 4  # optional
    MODE_OFB = 5
    MODE_CTR = 6


    class BlockCipher:
        def encrypt_block(self, block: bytes) -> bytes:
            raise NotImplementedError

        def decrypt_block(self, block: bytes) -> bytes:
            raise NotImplementedError

        # <removed lines of code>


    class Counter:
        def __init__(self, nonce: bytes, block_size: int, byte_order='big'):
            self.nonce = nonce
            self.counter_size = block_size - len(nonce)
            self.byte_order = byte_order
            self.counter = 0

        def __call__(self) -> bytes:
            out = self.nonce + self.counter.to_bytes(
                self.counter_size, self.byte_order
            )
            self.counter += 1
            return out


    class BlockCipherWrapper:
        def __init__(self):
            """initiate instance attributes."""
            # PEP 272 required attributes
            self.block_size: int = NotImplemented  # measured in bytes
            self.IV: bytes = NotImplemented  # initialization vector
            # other attributes
            self.mode: int = NotImplemented
            self.cipher: BlockCipher = NotImplemented
            self.counter: Counter = NotImplemented
            self.segment_size: int = NotImplemented

        def encrypt(self, byte_string: bytes) -> bytes:
            if self.mode == MODE_CFB and len(byte_string) * 8 % self.segment_size:
                raise ValueError("message length doesn't match segment size")
            if self.mode == MODE_CFB and self.segment_size & 7:
                raise NotImplementedError
            if self.mode != MODE_CFB and len(byte_string) % self.block_size:
                raise ValueError("message length doesn't match block size")
            blocks = [
                byte_string[i:i+self.block_size]
                for i in range(0, len(byte_string), self.block_size)
            ]
            if self.mode == MODE_ECB:
                return b''.join([
                    self.cipher.encrypt_block(block) for block in blocks
                ])
            elif self.mode == MODE_CBC:
                cipher_blocks = [self.IV]
                for block in blocks:
                    cipher_blocks.append(
                        self.cipher.encrypt_block(
                            self.xor(block, cipher_blocks[-1])
                        )
                    )
                return b''.join(cipher_blocks[1:])
            elif self.mode == MODE_CFB:
                s = self.segment_size >> 3
                cipher = b''
                current_input = self.IV
                while byte_string:
                    cipher += self.xor(
                        byte_string[:s],
                        self.cipher.encrypt_block(current_input)[:s]
                    )
                    byte_string = byte_string[s:]
                    current_input = current_input[s:] + cipher[-s:]
                return cipher
            elif self.mode == MODE_PGP:
                raise NotImplementedError
            elif self.mode == MODE_OFB:
                last_output = self.IV
                cipher_blocks = [self.IV]
                for block in blocks:
                    last_output = self.cipher.encrypt_block(last_output)
                    cipher_blocks.append(self.xor(block, last_output))
                return b''.join(cipher_blocks[1:])
            elif self.mode == MODE_CTR:
                cipher_blocks = []
                for block in blocks:
                    ctr = self.counter()
                    if len(ctr) != self.block_size:
                        raise ValueError("counter has the wrong size")
                    cipher_blocks.append(
                        self.xor(self.cipher.encrypt_block(ctr), block)
                    )
                return b''.join(cipher_blocks)
            else:
                raise NotImplementedError("This mode is not supported")

        def decrypt(self, byte_string: bytes) -> bytes:
            if self.mode == MODE_CFB and len(byte_string) * 8 % self.segment_size:
                raise ValueError("message length doesn't match segment size")
            if self.mode == MODE_CFB and self.segment_size & 7:
                raise NotImplementedError
            if self.mode != MODE_CFB and len(byte_string) % self.block_size:
                raise ValueError("message length doesn't match block size")
            # split up into blocks
            blocks = [
                byte_string[i:i+self.block_size]
                for i in range(0, len(byte_string), self.block_size)
            ]
            if self.mode == MODE_ECB:
                return b''.join([
                    self.cipher.decrypt_block(block) for block in blocks
                ])
            elif self.mode == MODE_CBC:
                plain_blocks = []
                blocks.insert(0, self.IV)
                for i in range(1, len(blocks)):
                    plain_blocks.append(self.xor(
                        self.cipher.decrypt_block(blocks[i]),
                        blocks[i-1]
                    ))
                return b''.join(plain_blocks)
            elif self.mode == MODE_CFB:
                s = self.segment_size >> 3
                plain = b''
                current_input = self.IV
                while byte_string:
                    plain += self.xor(
                        byte_string[:s],
                        self.cipher.encrypt_block(current_input)[:s]
                    )
                    current_input = current_input[s:] + byte_string[:s]
                    byte_string = byte_string[s:]
                return plain
            elif self.mode == MODE_PGP:
                raise NotImplementedError("PGP mode is not supported")
            elif self.mode == MODE_OFB:
                return self.encrypt(byte_string)
            elif self.mode == MODE_CTR:
                return self.encrypt(byte_string)
            else:
                raise ValueError("unknown mode")

        def xor(self, block1, block2):
            size = (
                self.segment_size >> 3
                if self.mode == MODE_CFB
                else self.block_size
            )
            if not (len(block1) == len(block2) == size):
                raise ValueError(str(size))
            return bytes([block1[i] ^ block2[i] for i in range(size)])


    def main():
        pass


    if __name__ == '__main__':
        main()
\$\endgroup\$

    2 Answers

    4
    \$\begingroup\$

    That is a lot of code to review, so this review will only touch on more superficial things.

    1. I'm not familiar with the pattern of modifying __all__, but it seems like using an underscore at the start of names which should not be available from the outside would be a less "magic" way of achieving the same thing.
    2. black and isort can help make the files more idiomatic, even if that means the very long lists will be linearized.
    3. Stricter type annotations (for all parameters and late-initialized variables) would make the code more self-documenting.
    4. pylint and flake8 can notify you of other non-idiomatic code, such as top-level variable names which are not all uppercase.
    5. key_size doesn't seem to be used for anything. Edit: Since it's required by PEP 272, I would suggest adding a comment to this effect. Otherwise it's pretty likely to be removed by the next developer unless its presence is actually enforced by commit hooks and/or CI.
    6. Some of the variable and field names are not useful, like all the single letter ones and tmp.
    7. Does "inv" in INV_SBOX stand for "inverse"? If so, can this variable be generated from SBOX?
    8. Several things here could usefully be enums, such as SBOX and the modes.
    9. @staticmethod is a code smell (for anyone tempted to comment, that doesn't always mean it's a bad thing). Often it's done because the IDE suggests it, but in my experience this is usually not helpful, either because there's no reason to ever call the method directly or because it is such a generic piece of code (often an idempotent input transformation) that it really belongs outside the class. In this case, rot_word and sub_word are not actually called on the class. Whether they belong as functions or internal methods would depend on how they are likely to be used and reused.
    10. set_state and get_state both transform the state. I would normally expect only one of them to actually transform it.
    11. The stuff in aes.py's main function belongs in a test case. I would expect main to parse arguments and to dispatch an action for the rest of the code.
    12. The BlockCipher class doesn't do anything, and there is only one child class, so it looks like it can be removed without losing anything.
    13. xor seems like a strange name for a method which doesn't just return first ^ second. Is there a name which more closely says what it actually does?
    14. block_cipher.py has a no-op main function. I'm sure it's meant to eventually contain something, but I would normally reject that as part of a merge request because it's effectively dead code.
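
On point 7: an S-box is a permutation of 0..255, so its inverse can be derived instead of typed out. The following is a minimal sketch, assuming a hypothetical helper `invert_table` (not part of the posted code) and a toy four-entry permutation standing in for the real table:

```python
def invert_table(table):
    """Return the inverse of a lookup table that is a permutation of range(len(table))."""
    inverse = [0] * len(table)
    for index, value in enumerate(table):
        # If table maps index -> value, the inverse maps value -> index.
        inverse[value] = index
    return tuple(inverse)


# Toy permutation standing in for the real 256-entry SBOX (assumption for brevity).
TOY_SBOX = (3, 0, 2, 1)
TOY_INV_SBOX = invert_table(TOY_SBOX)

# Substituting and then inverse-substituting returns every input unchanged.
round_trip_ok = all(TOY_INV_SBOX[TOY_SBOX[x]] == x for x in range(len(TOY_SBOX)))
```

With the posted tables, `INV_SBOX = invert_table(SBOX)` should reproduce the hardcoded tuple, and the same round-trip check doubles as a cheap test against typos in either table.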
    \$\endgroup\$
    2
    • Thanks for your answer! You're right, key_size isn't used for anything, but PEP 272 requires it for compatibility with other cipher modules / implementations. Thanks for the tool recommendations, I'll check them out.
      – Aemyl
      Commented Dec 23, 2020 at 21:06
    • xor does just return first^second after some checking.
      Commented Dec 24, 2020 at 16:14
    1
    \$\begingroup\$

    Advanced Encryption Standard.

    Please include a well-defined reference to the standard (it's FIPS 197 by NIST). If you use any of the notation from the standard, indicate that as well; it will make your code much easier to understand.

    'MODE_ECB', 'MODE_CBC', 'MODE_CFB', 'MODE_OFB', 'MODE_CTR' 

    Modes of operation are not part of the block cipher itself. I'm not that familiar with the Python language, but if possible the import of these modes should be avoided for the block cipher implementation itself.

    INV_SBOX = (

    The blocks are nicely formed and well named. But I wonder why there are so many double empty lines before some statements, and this one is missing a single one.

    def new(key, mode, IV=None, **kwargs) -> BlockCipherWrapper:

    There should be no need to repeat this for each and every cipher you implement. Also, the more modes are added, the more this will expand. What about CCM, EAX, GCM, SIV, disk modes? This is where you can have a Mode class and perform the initialization within the implementations of that class.
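
One way to read this suggestion is sketched below, under the assumption that each mode becomes its own class that validates its own parameters. The names `Mode`, `ECB`, `CBC` and `CTR` and the reshaped `new` signature are hypothetical and do not follow PEP 272:

```python
class Mode:
    """Base class: a mode of operation wraps any block cipher object."""

    def __init__(self, cipher):
        self.cipher = cipher


class ECB(Mode):
    """ECB needs no extra parameters, so it inherits the base initializer."""


class CBC(Mode):
    def __init__(self, cipher, iv):
        super().__init__(cipher)
        if iv is None:
            raise ValueError("CBC mode requires an IV")
        self.iv = iv


class CTR(Mode):
    def __init__(self, cipher, counter):
        super().__init__(cipher)
        if not callable(counter):
            raise ValueError("CTR mode requires a callable counter object")
        self.counter = counter


def new(cipher, mode_class, **kwargs):
    """Cipher-agnostic factory: the mode-specific checks live in the mode classes."""
    return mode_class(cipher, **kwargs)
```

Adding another mode then means adding one class, and `new` stays untouched for every cipher that reuses it.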

     if self.Nk not in (4, 6, 8): 

    It would be better to define constants for the key sizes of 128, 192 and 256 bits and then divide by BYTE_SIZE (or just 8, if you think that's obvious enough).
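
A minimal sketch of that idea follows. The names `BYTE_SIZE`, `WORD_SIZE`, `VALID_KEY_BITS` and `words_per_key` are assumptions chosen for illustration, not identifiers from the posted code:

```python
BYTE_SIZE = 8                      # bits per byte
WORD_SIZE = 4                      # bytes per 32-bit word
VALID_KEY_BITS = (128, 192, 256)   # key sizes defined for AES


def words_per_key(key: bytes) -> int:
    """Return Nk (key length in 32-bit words) or raise for unsupported sizes."""
    key_bits = len(key) * BYTE_SIZE
    if key_bits not in VALID_KEY_BITS:
        raise ValueError(f"invalid key size: {key_bits} bits")
    return len(key) // WORD_SIZE
```

Checking the bit length before dividing also rejects a 17-byte key, which the integer-division check `len(key) // 4 in (4, 6, 8)` would let through.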

     # key schedule 

    Here you could point to the right chapter in the standard. The code is clear in this section when it comes to how the algorithm is implemented, but you're not really explaining what you are doing (apart from performing the key schedule, obviously).


    The AES implementation itself is nice and well structured enough. The names are fine, the code looks fine. I'd add a few more lines about the "what", but since it is unlikely to change much and the function of these methods is well understood, you can be excused for not including it.


    Generally I would take a look at when you are including empty lines. The empty line directly after a method declaration doesn't look good to me, especially since you yourself choose to ignore it for comments. This makes the code look slightly unbalanced to me. I'd just directly start after the method declaration myself, less is more.


    MODE_PGP = 4 # optional

    What's optional? Optional for who?

    def encrypt(self, byte_string: bytes) -> bytes:

    Here and in decrypt you are falling into the same trap. Use separate classes and maybe defer to them, but please do not implement all modes in a single method. Check for instance how the Bouncy Castle lightweight API is implemented.

    Also, it isn't clear to me if you need to handle a full message at once or not.
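
As a rough illustration of one class per mode, here is a sketch with hypothetical `ECBMode` and `CBCMode` classes. `ToyCipher` is a deliberately insecure stand-in for the AES class so that the example stays self-contained; it only mimics the `encrypt_block`/`decrypt_block` interface:

```python
class ToyCipher:
    """Insecure stand-in for AES: XORs each block with a fixed 16-byte key."""

    block_size = 16

    def __init__(self, key: bytes):
        self.key = key

    def encrypt_block(self, block: bytes) -> bytes:
        return bytes(a ^ b for a, b in zip(block, self.key))

    decrypt_block = encrypt_block  # XOR with the same key undoes itself


def split_blocks(data: bytes, size: int) -> list:
    if len(data) % size:
        raise ValueError("message length doesn't match block size")
    return [data[i:i + size] for i in range(0, len(data), size)]


def xor_bytes(left: bytes, right: bytes) -> bytes:
    return bytes(a ^ b for a, b in zip(left, right))


class ECBMode:
    def __init__(self, cipher):
        self.cipher = cipher

    def encrypt(self, data: bytes) -> bytes:
        blocks = split_blocks(data, self.cipher.block_size)
        return b"".join(self.cipher.encrypt_block(b) for b in blocks)

    def decrypt(self, data: bytes) -> bytes:
        blocks = split_blocks(data, self.cipher.block_size)
        return b"".join(self.cipher.decrypt_block(b) for b in blocks)


class CBCMode:
    def __init__(self, cipher, iv: bytes):
        self.cipher = cipher
        self.iv = iv

    def encrypt(self, data: bytes) -> bytes:
        previous, out = self.iv, []
        for block in split_blocks(data, self.cipher.block_size):
            # Each plaintext block is chained with the previous ciphertext block.
            previous = self.cipher.encrypt_block(xor_bytes(block, previous))
            out.append(previous)
        return b"".join(out)

    def decrypt(self, data: bytes) -> bytes:
        previous, out = self.iv, []
        for block in split_blocks(data, self.cipher.block_size):
            out.append(xor_bytes(self.cipher.decrypt_block(block), previous))
            previous = block
        return b"".join(out)
```

Each class now holds only the logic of its own mode, so a wrapper's `encrypt` could simply defer to whichever mode object it was given.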

    elif self.mode == MODE_PGP: raise NotImplementedError

    Either leave it out or implement it, but don't leave it hanging. Note that there doesn't seem to be any way to test if the implementation is available. What about returning a set of enumerations to see which modes are available?
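
A sketch of how available modes could be advertised, assuming a hypothetical `CipherMode` enum whose values mirror the posted `MODE_*` constants and a hypothetical `SUPPORTED_MODES` set:

```python
from enum import IntEnum


class CipherMode(IntEnum):
    """Mode numbers, using the same values as the posted MODE_* constants."""
    ECB = 1
    CBC = 2
    CFB = 3
    PGP = 4
    OFB = 5
    CTR = 6


# Modes this sketch treats as implemented; PGP is deliberately left out.
SUPPORTED_MODES = frozenset(CipherMode) - {CipherMode.PGP}


def is_supported(mode: int) -> bool:
    """Report whether a mode number is implemented."""
    return mode in SUPPORTED_MODES
```

Because `IntEnum` members compare equal to plain integers, callers that still pass the numeric constants keep working, and they can query `SUPPORTED_MODES` instead of discovering a missing mode through `NotImplementedError`.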

    def xor(self, block1, block2): 

    This is a bit of a mixed bag method. You should at the very least make it private (__xor).

    \$\endgroup\$
