import base64
import os
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.serialization import (
    Encoding,
    NoEncryption,
    ParameterFormat,
    PublicFormat,
    PrivateFormat,
)
import cryptography.hazmat.primitives.serialization as Serialization
import msgpack
from Daisy.Store import Store

# TODO: Different store directories per node


class DHEFern:
    """
    Diffie Hellman key exchange plus Fernet symmetric encryption, backed by
    Daisy stores for persisting parameters and keys

    Attributes
    ----------
    cLog
        Method reference to `run.Node.cLog` so we can log to the ui from here

    stores: dict
        Daisy `Store` instances keyed by store name ("param" or "key")

    loadedParams: dict
        In memory representations of cryptography parameters

    loadedKeys: dict
        In memory representations of cryptography keys

    nodeNickname: str
        Name of node for isolating configs when running multiple nodes

    cache: Components.daisy.Cache
        Daisy cache for use in storing cryptography information

    params
        Most recently generated or loaded Diffie Hellman parameters

    publicKey
        Public key for node (only set once `genKeyPair` has run with setSelf)

    privateKey
        Private key for node (only set once `genKeyPair` has run with setSelf)
    """

    def __init__(self, cache, nodeNickname, cLog):
        """
        Parameters
        ----------
        cache: Components.daisy.Cache
            Reference to the node instances Daisy cache

        nodeNickname: str
            Node nickname for record storage

        cLog
            Reference to `run.Node.cLog`
        """
        self.cLog = cLog
        self.stores = {}
        self.loadedParams = {}
        self.loadedKeys = {}
        self.nodeNickname = nodeNickname
        self.cache = cache
        if not os.path.exists("daisy/cryptography/{0}/param".format(nodeNickname)):
            # initStore generates fresh parameters and sets self.params
            self.initStore("param")
        else:
            self.stores["param"] = Store("param", "cryptography", nodeNickname)
            self.params = self.loadParamBytes(self.stores["param"].get()["self"])
        self.cLog(20, "Param store initialized")
        if not os.path.exists("daisy/cryptography/{0}/key".format(nodeNickname)):
            self.cLog(20, "Key store DNE, initializing")
            self.initStore("key")
            self.genKeyPair()
        else:
            self.cLog(20, "Key store exists, loading")
            self.stores["key"] = Store("key", "cryptography", nodeNickname)
            self.cLog(20, "Store loaded")
            # TODO: Load the stored key pair; until then publicKey/privateKey
            # are not set on this path
            # tks = self.stores["key"].get()
            # self.publicKey = tks["self"]["publicKey"]
            # self.privateKey = tks["self"]["privateKey"]
            self.cLog(20, "Key store initialized")

    def checkInMem(self, store: str, nodeID: str) -> bool:
        """
        Check if parameters or keys are loaded for node of nodeID

        Parameters
        ----------
        store: str
            Whether to check loaded keys ("key") or parameters ("param")

        nodeID: str
            Node ID to look up

        Returns
        -------
        bool
            True if an in memory entry exists, False otherwise (including
            when store is not a known store name)
        """
        if store == "param":
            return nodeID in self.loadedParams
        elif store == "key":
            return nodeID in self.loadedKeys
        return False

    def loadRecordToMem(self, store: str, nodeID: str):
        """
        Load record of nodeID from store to either keys or parameters

        Returns
        -------
        bool
            False if the record does not exist in the store, True otherwise
        """
        r = self.getRecord(store, nodeID)
        if r is False:
            self.cLog(
                30, "Tried to load nonexistent {0} for node {1}".format(store, nodeID)
            )
            return False
        elif self.checkInMem(store, nodeID):
            self.cLog(10, "{0}s already deserialized, skipping".format(store))
        else:
            if store == "param":
                # NOTE: loadParamBytes also overwrites self.params
                self.loadedParams[nodeID] = self.loadParamBytes(r)
            elif store == "key":
                self.loadedKeys[nodeID] = {
                    "publicKey": Serialization.load_pem_public_key(r["publicKey"]),
                    "privateKey": Serialization.load_pem_private_key(
                        r["privateKey"], None
                    ),
                }
        return True

    def getRecord(self, store: str, key: str):
        """
        Get record from store: store with key: key

        Returns
        -------
        The stored record, or False if the store reports no such record
        """
        # Must go through self.stores; a bare `stores` name is undefined here
        r = self.stores[store].getRecord(key)
        if r is False:
            self.cLog(20, "Record does not exist")
            return False
        return r

    def updateStore(self, store: str, entry: str, data, recur=None):
        """
        Write data under entry in store: store

        NOTE(review): keyDerive calls this method but no definition was
        present in this class; this wrapper forwards to `Store.update` the
        same way initStore and genKeyPair call it. Confirm against Daisy.Store

        Parameters
        ----------
        store: str
            Name of the store to update ("param" or "key")

        entry: str
            Record key to write

        data
            Record contents

        recur
            Forwarded to `Store.update` as `recur` when given; when left as
            None the store's own default is used
        """
        if recur is None:
            self.stores[store].update(entry, data)
        else:
            self.stores[store].update(entry, data, recur=recur)

    def initStore(self, store: str):
        """
        Initialize store: store

        For "param" this generates new parameters and saves them under
        "self"; for "key" it saves an empty "self" record
        """
        self.stores[store] = Store(store, "cryptography", self.nodeNickname)
        if store == "param":
            self.genParams()
            self.stores[store].update("self", self.getParamsBytes(), recur=False)
        elif store == "key":
            self.stores[store].update("self", {}, recur=False)
        else:
            self.cLog(30, "Store not defined")

    def genParams(self):
        """
        Generate Diffie Hellman parameters, store them on self.params and
        return them
        """
        params = dh.generate_parameters(generator=2, key_size=2048)
        self.params = params
        return params

    def getParamsBytes(self):
        """
        Get PEM (PKCS3) bytes encoded from self.params (TODO: Encode from store)
        """
        return self.params.parameter_bytes(Encoding.PEM, ParameterFormat.PKCS3)

    def loadParamBytes(self, pemBytes: bytes):
        """
        Load parameters to self.params from given PEM bytes and return them
        (TODO: Load from store)
        """
        self.params = Serialization.load_pem_parameters(pemBytes)
        return self.params

    def genKeyPair(self, paramsOverride=False, setSelf: bool = True):
        """
        Generate public and private keys from self.params (TODO: Gen from passed params)

        Parameters
        ----------
        paramsOverride
            False or parameters to use (TODO, currently unused)

        setSelf: bool
            Whether to set self.privateKey and self.publicKey

        Returns
        -------
        list
            [privateKey, publicKey]; key objects when setSelf is True,
            PEM encoded bytes when setSelf is False
        """
        privateKey = self.params.generate_private_key()
        publicKey = privateKey.public_key()
        if setSelf:
            self.privateKey = privateKey
            self.publicKey = publicKey
            self.stores["key"].update(
                "self",
                {
                    "publicKey": self.publicKey.public_bytes(
                        Encoding.PEM, PublicFormat.SubjectPublicKeyInfo
                    ),
                    "privateKey": self.privateKey.private_bytes(
                        Encoding.PEM, PrivateFormat.PKCS8, NoEncryption()
                    ),
                },
            )
            return [privateKey, publicKey]
        publicKey = publicKey.public_bytes(
            Encoding.PEM, PublicFormat.SubjectPublicKeyInfo
        )
        privateKey = privateKey.private_bytes(
            Encoding.PEM, PrivateFormat.PKCS8, NoEncryption()
        )
        return [privateKey, publicKey]

    def keyDerive(self, pubKey: bytes, salt: bytes, nodeID: str, params: bytes):
        """
        Derive shared key using Diffie Hellman

        Parameters
        ----------
        pubKey: bytes
            PEM encoded public key of the other node

        salt: bytes
            Salt passed to HKDF

        nodeID: str
            PierMesh node ID

        params: bytes
            Encryption parameters, stored for nodeID if none are on record

        Returns
        -------
        bytes
            URL safe base64 encoding of the 32 byte derived key
        """
        if not self.checkInMem("param", nodeID):
            if self.getRecord("param", nodeID) is False:
                self.updateStore("param", nodeID, params, recur=False)
            # Side effect: this also points self.params at the loaded
            # parameters, which genKeyPair below then uses
            self.loadRecordToMem("param", nodeID)
        self.cLog(20, "Precheck done for key derivation")

        # TODO: Load them and if private key exists load it, otherwise generate a private key
        if not self.checkInMem("key", nodeID):
            if self.getRecord("key", nodeID) is False:
                privateKey, publicKey = self.genKeyPair(setSelf=False)
                self.updateStore(
                    "key", nodeID, {"publicKey": publicKey, "privateKey": privateKey}
                )
            self.loadRecordToMem("key", nodeID)

        sharedKey = self.loadedKeys[nodeID]["privateKey"].exchange(
            Serialization.load_pem_public_key(pubKey)
        )
        # Perform key derivation.
        self.cLog(20, "Performing key derivation")
        derivedKey = HKDF(
            algorithm=hashes.SHA256(), length=32, salt=salt, info=b"handshake data"
        ).derive(sharedKey)
        self.cLog(20, "Derived key")
        ederivedKey = base64.urlsafe_b64encode(derivedKey)
        tr = self.getRecord("key", nodeID)
        tr["derivedKey"] = ederivedKey
        self.updateStore("key", nodeID, tr)
        self.cLog(20, "Done with cryptography store updates")
        return ederivedKey

    def getSalt(self):
        """
        Get random salt (16 bytes from os.urandom)
        """
        return os.urandom(16)

    def encrypt(self, data, nodeID: str, isDict: bool = True):
        """
        Do Fernet encryption

        Parameters
        ----------
        data
            Either bytes or dict to encrypt

        nodeID: str
            Node whose derived key is used

        isDict: bool
            Whether data is a dictionary (msgpack encoded before encryption)

        Returns
        -------
        Fernet token bytes, or False if there is no key record or no
        derived key for nodeID
        """
        r = self.getRecord("key", nodeID)
        if r is False:
            self.cLog(20, "Node {0} not in keystore".format(nodeID))
            return False
        if "derivedKey" not in r:
            # Same guard as decrypt: a record can exist before keyDerive ran
            self.cLog(20, "No key derived for node " + nodeID)
            return False
        fernet = Fernet(r["derivedKey"])
        if isDict:
            data = msgpack.dumps(data)
        return fernet.encrypt(data)

    def decrypt(self, data, nodeID: str):
        """
        Decrypt bytes and return the msgpack decoded result
        (TODO: Check whether to msgpack load)

        Returns
        -------
        Decoded plaintext, or False if there is no key record or no
        derived key for nodeID
        """
        r = self.getRecord("key", nodeID)
        if r is False:
            self.cLog(20, "No record of node " + nodeID)
            return False
        if "derivedKey" not in r:
            self.cLog(20, "No key derived for node " + nodeID)
            return False
        fernet = Fernet(r["derivedKey"])
        return msgpack.loads(fernet.decrypt(data))