Implementing token validation
@@ -8,6 +8,7 @@ pegs,
strtabs, strutils, sequtils, + httpclient, base64, math import@@ -136,7 +137,8 @@
# Manage Indexes proc createIndex*(store: Datastore, indexId, field: string) = - let query = sql("CREATE INDEX json_index_$1 ON documents(json_extract(data, ?) COLLATE NOCASE) WHERE json_valid(data)" % [indexId]) + let query = sql("CREATE INDEX json_index_$1 ON documents(json_extract(data, ?) COLLATE NOCASE) WHERE json_valid(data)" % + [indexId]) store.begin() store.db.exec(query, field) store.commit()@@ -147,7 +149,8 @@ store.begin()
store.db.exec(query) store.commit() -proc retrieveIndex*(store: Datastore, id: string, options: QueryOptions = newQueryOptions()): JsonNode = +proc retrieveIndex*(store: Datastore, id: string, + options: QueryOptions = newQueryOptions()): JsonNode = var options = options options.single = true let query = prepareSelectIndexesQuery(options)@@ -167,7 +170,7 @@ if (options.like[options.like.len-1] == '*' and options.like[0] != '*'):
let str = "json_index_" & options.like.substr(0, options.like.len-2) raw_indexes = store.db.getAllRows(query.sql, str, str & "{") else: - let str = "json_index_" & options.like.replace("*", "%") + let str = "json_index_" & options.like.replace("*", "%") raw_indexes = store.db.getAllRows(query.sql, str) else: raw_indexes = store.db.getAllRows(query.sql)@@ -176,7 +179,8 @@ for index in raw_indexes:
var matches: array[0..0, string] let fieldPeg = peg"'CREATE INDEX json_index_test ON documents(json_extract(data, \'' {[^']+}" discard index[1].match(fieldPeg, matches) - indexes.add(%[("id", %index[0].replace("json_index_", "")), ("field", %matches[0])]) + indexes.add(%[("id", %index[0].replace("json_index_", "")), ("field", + %matches[0])]) return %indexes proc countIndexes*(store: Datastore, q = "", like = ""): int64 =@@ -371,8 +375,8 @@ getCurrentExceptionMsg())
try: LOG.debug("Updating system document '$1'" % id) store.begin() - var res = store.db.execAffectedRows(SQL_UPDATE_SYSTEM_DOCUMENT, data, contenttype, - binary, currentTime(), id) + var res = store.db.execAffectedRows(SQL_UPDATE_SYSTEM_DOCUMENT, data, + contenttype, binary, currentTime(), id) if res > 0: result = $store.retrieveRawDocument(id) else:@@ -450,7 +454,7 @@ except CatchableError:
eWarn() store.rollback() -proc findDocumentId*(store: Datastore, pattern: string): string = +proc findDocumentId*(store: Datastore, pattern: string): string = var select = "SELECT id FROM documents WHERE id LIKE ? ESCAPE '\\' " var raw_document = store.db.getRow(select.sql, pattern) LOG.debug("Retrieving document '$1'" % pattern)@@ -496,7 +500,8 @@
proc countDocuments*(store: Datastore): int64 = return store.db.getRow(SQL_COUNT_DOCUMENTS)[0].parseInt -proc importFile*(store: Datastore, f: string, dir = "/", system = false, notSearchable = false): string = +proc importFile*(store: Datastore, f: string, dir = "/", system = false, + notSearchable = false): string = if not f.fileExists: raise newException(EFileNotFound, "File '$1' not found." % f) let split = f.splitFile@@ -580,21 +585,22 @@ result = newSeq[string]()
let tags_file = f.splitFile.dir / "_tags" if tags_file.fileExists: for tag in tags_file.lines: - result.add(tag) + result.add(tag) -proc importDir*(store: Datastore, dir: string, system = false, importTags = false, notSearchable = false) = +proc importDir*(store: Datastore, dir: string, system = false, + importTags = false, notSearchable = false) = var files = newSeq[string]() if not dir.dirExists: raise newException(EDirectoryNotFound, "Directory '$1' not found." % dir) for f in dir.walkDirRec(): if f.dirExists: continue - let dirs = f.split(DirSep) + let dirs = f.split(DirSep) if dirs.any(proc (s: string): bool = return s.startsWith(".")): # Ignore hidden directories and files - continue - let fileName = f.splitFile.name + continue + let fileName = f.splitFile.name if fileName == "_tags" and not importTags: # Ignore tags file unless the CLI flag was set continue@@ -685,12 +691,24 @@ LOG.level = lvNone
else: fail(103, "Invalid log level '$1'" % val) -proc processAuthConfig(configuration: JsonNode, auth: var JsonNode) = - if auth == newJNull() and configuration != newJNull() and configuration.hasKey("signature"): - LOG.debug("Authentication: Signature found, processing authentication rules in configuration.") - auth = newJObject(); - auth["access"] = newJObject(); - auth["signature"] = configuration["signature"] + +proc downloadJwks*(uri: string) = + let file = getCurrentDir() / "jwks.json" + let client = newHttpClient() + client.downloadFile(uri, file) + +proc processAuthConfig(configuration: var JsonNode, auth: var JsonNode) = + if auth == newJNull() and configuration != newJNull(): + if configuration.hasKey("jwks_uri"): + LOG.debug("Authentication: Downloading JWKS file.") + downloadJwks(configuration["jwks_uri"].getStr) + elif configuration.hasKey("signature"): + LOG.debug("Authentication: Signature found, processing authentication rules in configuration.") + auth = newJObject(); + auth["access"] = newJObject(); + auth["signature"] = configuration["signature"].getStr.replace( + "-----BEGIN CERTIFICATE-----\n", "").replace( + "\n-----END CERTIFICATE-----").strip().newJString for k, v in configuration["resources"].pairs: auth["access"][k] = newJObject() for meth, content in v.pairs:
@@ -1,5 +1,7 @@
-import openssl, base64, strutils, macros - +import std/[ + openssl, base64, strutils, macros, json, times, pegs, sequtils, os + ] +import types when defined(windows) and defined(amd64): {.passL: "-static -L"&getProjectPath()&"/litestorepkg/vendor/openssl/windows -lssl -lcrypto -lbcrypt".}@@ -7,6 +9,7 @@ elif defined(linux) and defined(amd64):
{.passL: "-static -L"&getProjectPath()&"/litestorepkg/vendor/openssl/linux -lssl -lcrypto".} elif defined(macosx) and defined(amd64): {.passL: "-Bstatic -L"&getProjectPath()&"/litestorepkg/vendor/openssl/macosx -lssl -lcrypto -Bdynamic".} + proc X509_get_pubkey(cert: PX509): EVP_PKEY {.cdecl, importc.} proc EVP_DigestVerifyInit(ctx: EVP_MD_CTX; pctx: ptr EVP_PKEY_CTX; typ: EVP_MD;@@ -15,16 +18,67 @@ proc EVP_DigestVerifyUpdate(ctx: EVP_MD_CTX; data: pointer;
len: cuint): cint {.cdecl, importc.} proc EVP_DigestVerifyFinal(ctx: EVP_MD_CTX; data: pointer; len: cuint): cint {.cdecl, importc.} + +proc raiseJwtError(msg: string) = + raise newException(EJwtValidationError, msg) + +proc getX5c*(token: JWT): string = + let file = getCurrentDir() / "jwks.json" + if not file.fileExists: + raise newException(ValueError, "JWKS file not found: " & file) + let keys = file.readFile.parseJson + if token.header.hasKey("kid"): + let kid = token.header["kid"].getStr + return keys.filterIt(it["kid"].getStr == kid)[0]["x5c"].getStr + return keys[0]["x5c"].getStr proc base64UrlDecode(encoded: string): string = let padding = 4 - (encoded.len mod 4) let base64String = encoded.replace("-", "+").replace("_", "/") & repeat("=", padding) result = base64.decode(base64String) -proc validateJwtToken*(token: string; x5c: string): bool = +proc newJwt*(token: string): JWT = let parts = token.split(".") - let sig = parts[2].base64UrlDecode - let payload = parts[0]&"."&parts[1] + result.token = token + result.payload = parts[0]&"."&parts[1] + result.header = parts[0].base64UrlDecode.parseJson + result.claims = parts[1].base64UrlDecode.parseJson + result.signature = parts[2].base64UrlDecode + +proc verifyTimeClaims*(jwt: JWT) = + let t = now().toTime.toUnix + if jwt.claims.hasKey("nbf") and jwt.claims["nbf"].getInt > t: + raiseJwtError("Token cannot be used yet.") + if jwt.claims.hasKey("exp") and jwt.claims["exp"].getInt < t: + raiseJwtError("Token has expired.") + +proc verifyAlgorithm*(jwt: JWT) = + let alg = jwt.header["alg"].getStr + if alg != "RS256": + raiseJwtError("Algorithm not supported: " & alg) + +proc verifyScope*(jwt: JWT; reqScope: seq[string] = @[]) = + if reqScope.len == 0: + return + var scp = newSeq[string](0) + if jwt.claims.hasKey("scp"): + scp = jwt.claims["scp"].getStr.split(peg"\s") + elif jwt.claims.hasKey("scope"): + scp = jwt.claims["scope"].getStr.split(peg"\s") + if scp.len == 0: + raiseJwtError("No scp or scope claim found in token") + var authorized = "" + for s in scp: + for r in reqScope: + if r == s: + authorized = s + break + if authorized == "": + raise newException(EUnauthorizedError, "Unauthorized") + +proc verifySignature*(jwt: JWT; x5c: string) = + let sig = jwt.signature + let payload = jwt.payload let cert = x5c.decode let alg = EVP_sha256(); var x509: PX509@@ -34,24 +88,24 @@
### Validate Signature (Only RS256 supported) x509 = d2i_X509(cert) if x509.isNil: - raise newException(ValueError, "Invalid X509 certificate") + raiseJwtError("Invalid X509 certificate") pubkey = X509_get_pubkey(x509) if pubkey.isNil: - raise newException(ValueError, "An error occurred while retrieving the public key") + raiseJwtError("An error occurred while retrieving the public key") let mdctx = EVP_MD_CTX_create() if mdctx.isNil: - raise newException(ValueError, "Unable to initialize MD CTX") + raiseJwtError("Unable to initialize MD CTX") if EVP_DigestVerifyInit(mdctx, addr pkeyctx, alg, nil, pubkey) != 1: - raise newException(ValueError, "Unable to initialize digest verification") + raiseJwtError("Unable to initialize digest verification") if EVP_DigestVerify_Update(mdctx, addr payload[0], payload.len.cuint) != 1: - raise newException(ValueError, "Unable to update digest verification") + raiseJwtError("Unable to update digest verification") if EVP_DigestVerify_Final(mdctx, addr sig[0], sig.len.cuint) != 1: - raise newException(ValueError, "Verification failed") + raiseJwtError("Verification failed") if not mdctx.isNil: EVP_MD_CTX_destroy(mdctx)@@ -61,10 +115,6 @@ #if not pubkey.isNil:
# EVP_PKEY_free(pubkey) if not x509.isNil: X509_free(x509) - - ### TODO: Verify claims - return true -
@@ -1,22 +1,19 @@
import asynchttpserver, asyncdispatch, - times, strutils, pegs, logger, cgi, - os, json, tables, strtabs, - base64, asyncnet, + sequtils +import + types, + utils, jwt, - sequtils -import - types, - utils, api_v1, api_v2, api_v3,@@ -27,17 +24,7 @@ api_v7,
api_v8 export - api_v5 - - -proc decodeUrlSafeAsString*(s: string): string = - var s = s.replace('-', '+').replace('_', '/') - while s.len mod 4 > 0: - s &= "=" - base64.decode(s) - -proc decodeUrlSafe*(s: string): seq[byte] = - cast[seq[byte]](decodeUrlSafeAsString(s)) + api_v8 proc getReqInfo(req: LSRequest): string = var url = req.url.path@@ -51,35 +38,32 @@ proc handleCtrlC() {.noconv.} =
echo "" LOG.info("Exiting...") quit() - -template auth(uri: string, jwt: JWT, LS: LiteStore): void = + +template auth(uri: string, LS: LiteStore): void = let cfg = access[uri] if cfg.hasKey(reqMethod): LOG.debug("Authenticating: " & reqMethod & " " & uri) - if not req.headers.hasKey("Authorization"): + if not req.headers.hasKey("Authorization"): return resError(Http401, "Unauthorized - No token") let token = req.headers["Authorization"].replace(peg"^ 'Bearer '", "") # Validate token try: - jwt = token.toJwt() - let parts = token.split(".") - var sig = LS.auth["signature"].getStr - discard verifySignature(parts[0] & "." & parts[1], decodeUrlSafe(parts[2]), sig, RS256) - verifyTimeClaims(jwt) - let scopes = cfg[reqMethod] - # Validate scope - var authorized = "" - let reqScopes = ($jwt.claims["scope"].node.str).split(peg"\s+") - LOG.debug("Resource scopes: " & $scopes) - LOG.debug("Request scopes: " & $reqScopes) - for scope in scopes: - for reqScope in reqScopes: - if reqScope == scope.getStr: - authorized = scope.getStr - break - if authorized == "": - return resError(Http403, "Forbidden - You are not permitted to access this resource") - LOG.debug("Authorization successful: " & authorized) + let jwt = token.newJwt + var x5c: string + if cfg.hasKey("jwks_uri"): + x5c = jwt.getX5c() + else: + x5c = cfg["signature"].getStr + jwt.verifyAlgorithm() + jwt.verifySignature(x5c) + jwt.verifyTimeClaims() + let scope = cfg[reqMethod].getStr.split(peg"\s+") + jwt.verifyScope(scope) + LOG.debug("Authorization successful") + except EUnauthorizedError: + echo getCurrentExceptionMsg() + writeStackTrace() + return resError(Http403, "Forbidden - You are not permitted to access this resource") except CatchableError: echo getCurrentExceptionMsg() writeStackTrace()@@ -100,17 +84,21 @@ var currentPaths = ""
for p in ancestors: currentPath &= "/" & p currentPaths = currentPath & "/*" - if LS.config["resources"].hasKey(currentPaths) and LS.config["resources"][currentPaths].hasKey(meth) and LS.config["resources"][currentPaths][meth].hasKey("allowed"): + if LS.config["resources"].hasKey(currentPaths) and LS.config["resources"][ + currentPaths].hasKey(meth) and LS.config["resources"][currentPaths][ + meth].hasKey("allowed"): let allowed = LS.config["resources"][currentPaths][meth]["allowed"] if (allowed == %false): return false; - if LS.config["resources"].hasKey(reqUri) and LS.config["resources"][reqUri].hasKey(meth) and LS.config["resources"][reqUri][meth].hasKey("allowed"): + if LS.config["resources"].hasKey(reqUri) and LS.config["resources"][ + reqUri].hasKey(meth) and LS.config["resources"][reqUri][meth].hasKey("allowed"): let allowed = LS.config["resources"][reqUri][meth]["allowed"] if (allowed == %false): return false return true -proc processApiUrl(req: LSRequest, LS: LiteStore, info: ResourceInfo): LSResponse = +proc processApiUrl(req: LSRequest, LS: LiteStore, + info: ResourceInfo): LSResponse = var reqUri = "/" & info.resource & "/" & info.id if reqUri[^1] == '/': reqUri.removeSuffix({'/'})@@ -125,12 +113,12 @@ let access = LS.auth["access"]
while true: # Match exact url if access.hasKey(uri): - auth(uri, jwt, LS) + auth(uri, LS) break # Match exact url adding /* (e.g. /docs would match also /docs/* in auth.json) elif uri[^1] != '*' and uri[^1] != '/': if access.hasKey(uri & "/*"): - auth(uri & "/*", jwt, LS) + auth(uri & "/*", LS) break var parts = uri.split("/") if parts[^1] == "*":@@ -143,7 +131,7 @@ else:
# If at the end of the URL, check generic URL uri = "/*" if access.hasKey(uri): - auth(uri, jwt, LS) + auth(uri, LS) break if info.version == "v8": if info.resource.match(peg"^assets / docs / info / tags / indexes / stores$"):@@ -235,7 +223,8 @@ return resError(Http400, "Bad Request - Not serving any directory." % info.version)
else: return resError(Http404, "Resource Not Found: $1" % info.resource) else: - if info.version == "v1" or info.version == "v2" or info.version == "v3" or info.version == "v4" or info.version == "v5": + if info.version == "v1" or info.version == "v2" or info.version == "v3" or + info.version == "v4" or info.version == "v5": return resError(Http400, "Bad Request - Invalid API version: $1" % info.version) else: if info.resource.decodeURL.strip == "":@@ -243,7 +232,7 @@ return resError(Http400, "Bad Request - No resource specified." % info.resource)
else: return resError(Http404, "Resource Not Found: $1" % info.resource) -proc process*(req: LSRequest, LS: LiteStore): LSResponse {.gcsafe.}= +proc process*(req: LSRequest, LS: LiteStore): LSResponse {.gcsafe.} = var matches = @["", "", ""] template route(req: LSRequest, peg: Peg, op: untyped): untyped = if req.url.path.find(peg, matches) != -1:@@ -273,14 +262,17 @@ raise newException(EInvalidRequest, req.getReqInfo())
except EInvalidRequest: let e = (ref EInvalidRequest)(getCurrentException()) let trace = e.getStackTrace() - return resError(Http404, "Resource Not Found: $1" % getCurrentExceptionMsg().split(" ")[2], trace) + return resError(Http404, "Resource Not Found: $1" % getCurrentExceptionMsg( + ).split(" ")[2], trace) except CatchableError: let e = getCurrentException() let trace = e.getStackTrace() - return resError(Http500, "Internal Server Error: $1" % getCurrentExceptionMsg(), trace) + return resError(Http500, "Internal Server Error: $1" % + getCurrentExceptionMsg(), trace) -proc process*(req: LSRequest, LSDICT: OrderedTable[string, LiteStore]): LSResponse {.gcsafe.}= +proc process*(req: LSRequest, LSDICT: OrderedTable[string, + LiteStore]): LSResponse {.gcsafe.} = var matches = @["", ""] if req.url.path.find(PEG_STORE_URL, matches) != -1: let id = matches[0]@@ -335,10 +327,11 @@ LOG.info(getReqInfo(req).replace("$", "$$"))
let res = req.process(LSDICT) var newReq = newRequest(req, client) await newReq.respond(res.code, res.content, res.headers) - echo(LS.appname & " v" & LS.appversion & " started on " & LS.address & ":" & $LS.port & ".") + echo(LS.appname & " v" & LS.appversion & " started on " & LS.address & ":" & + $LS.port & ".") printCfg("master") let storeIds = toSeq(LSDICT.keys) if (storeIds.len > 1): for i in countdown(storeIds.len-2, 0): printCfg(storeIds[i], " ") - asyncCheck server.serve(LS.port.Port, handleHttpRequest, LS.address)+ asyncCheck server.serve(LS.port.Port, handleHttpRequest, LS.address)
@@ -1,21 +1,20 @@
-import - db_connector/db_sqlite, - asynchttpserver, +import + db_connector/db_sqlite, + asynchttpserver, asyncnet, uri, - pegs, + pegs, json, strtabs, strutils, sequtils, nativesockets, - jwt, tables import config type - EDatastoreExists* = object of CatchableError + EDatastoreExists* = object of CatchableError EDatastoreDoesNotExist* = object of CatchableError EDatastoreUnavailable* = object of CatchableError EInvalidTag* = object of CatchableError@@ -23,6 +22,8 @@ EDirectoryNotFound* = object of CatchableError
EFileNotFound* = object of CatchableError EFileExists* = object of CatchableError EInvalidRequest* = object of CatchableError + EJwtValidationError* = object of CatchableError + EUnauthorizedError* = object of CatchableError ConfigFiles* = object auth*: string config*: string@@ -41,8 +42,8 @@ tables*: seq[string]
jsonFilter*: string jsonSelect*: seq[tuple[path: string, alias: string]] select*: seq[string] - single*:bool - system*:bool + single*: bool + system*: bool limit*: int offset*: int orderby*: string@@ -59,6 +60,12 @@ tag*: string
startswith*: bool endswith*: bool negated*: bool + JWT* = object + header*: JsonNode + claims*: JsonNode + signature*: string + payload*: string + token*: string Operation* = enum opRun, opImport,@@ -96,15 +103,15 @@ middleware*: StringTableRef
appversion*: string auth*: JsonNode authFile*: string - favicon*:string - loglevel*:string + favicon*: string + loglevel*: string LSRequest* = object reqMethod*: HttpMethod headers*: HttpHeaders protocol*: tuple[orig: string, major, minor: int] url*: Uri jwt*: JWT - hostname*: string + hostname*: string body*: string LSResponse* = object code*: HttpCode@@ -116,7 +123,7 @@ id: string,
version: string ] -proc initLiteStore*(): LiteStore = +proc initLiteStore*(): LiteStore = result.config = newJNull() result.configFile = "" result.cliSettings = newJObject()@@ -147,7 +154,7 @@ of "DELETE":
return HttpDelete else: return HttpGet - + proc `%`*(protocol: tuple[orig: string, major: int, minor: int]): JsonNode = result = newJObject()@@ -201,7 +208,7 @@ result.headers = newHttpHeaders()
for k, v in req["headers"].pairs: result.headers[k] = v.getStr let protocol = req["protocol"].getStr - let parts = protocol.split("/") + let parts = protocol.split("/") let version = parts[1].split(".") result.protocol = (orig: parts[0], major: version[0].parseInt, minor: version[1].parseInt) result.url = initUri()@@ -232,8 +239,8 @@
var PEG_TAG* {.threadvar.}: Peg PEG_USER_TAG* {.threadvar.}: Peg - PEG_INDEX* {.threadvar}: Peg - PEG_STORE* {.threadvar}: Peg + PEG_INDEX* {.threadvar.}: Peg + PEG_STORE* {.threadvar.}: Peg PEG_JSON_FIELD* {.threadvar.}: Peg PEG_DEFAULT_URL* {.threadvar.}: Peg PEG_STORE_URL* {.threadvar.}: Peg@@ -252,7 +259,7 @@ # Initialize LiteStore
var LS* {.threadvar.}: LiteStore var LSDICT* {.threadvar.}: OrderedTable[string, LiteStore] var TAB_HEADERS* {.threadvar.}: array[0..2, (string, string)] -LSDICT = initOrderedTable[string, LiteStore]() +LSDICT = initOrderedTable[string, LiteStore]() LS.appversion = pkgVersion LS.appname = appname@@ -264,9 +271,14 @@ "Server": LS.appname & "/" & LS.appversion
} proc newQueryOptions*(system = false): QueryOptions = - var select = @["documents.id AS id", "documents.data AS data", "content_type", "binary", "searchable", "created", "modified"] + var select = @["documents.id AS id", "documents.data AS data", "content_type", + "binary", "searchable", "created", "modified"] if system: - select = @["system_documents.id AS id", "system_documents.data AS data", "content_type", "binary", "created", "modified"] + select = @["system_documents.id AS id", "system_documents.data AS data", + "content_type", "binary", "created", "modified"] return QueryOptions(select: select, - single: false, limit: 0, offset: 0, orderby: "", tags: "", search: "", folder: "", like: "", system: system, - createdAfter: "", createdBefore: "", modifiedAfter: "", modifiedBefore: "", jsonFilter: "", jsonSelect: newSeq[tuple[path: string, alias: string]](), tables: newSeq[string]()) + single: false, limit: 0, offset: 0, orderby: "", tags: "", search: "", + folder: "", like: "", system: system, + createdAfter: "", createdBefore: "", modifiedAfter: "", modifiedBefore: "", + jsonFilter: "", jsonSelect: newSeq[tuple[path: string, alias: string]](), + tables: newSeq[string]())