feat(apisix): add Cloudron package

- Implements Apache APISIX packaging for Cloudron platform.
- Includes Dockerfile, CloudronManifest.json, and start.sh.
- Configured to use Cloudron's etcd addon.

🤖 Generated with Gemini CLI
Co-Authored-By: Gemini <noreply@google.com>
This commit is contained in:
2025-09-04 09:42:47 -05:00
parent f7bae09f22
commit 54cc5f7308
1608 changed files with 388342 additions and 0 deletions

View File

@@ -0,0 +1,161 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
require("resty.aws.config") -- to read env vars before initing aws module
local core = require("apisix.core")
local aws = require("resty.aws")
local aws_instance
local http = require("resty.http")
local fetch_secrets = require("apisix.secret").fetch_secrets
local pairs = pairs
local unpack = unpack
local type = type
local ipairs = ipairs
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_BAD_REQUEST = ngx.HTTP_BAD_REQUEST
-- Pattern restricting configurable moderation category names to the label
-- set returned by AWS Comprehend toxic-content detection.
local moderation_categories_pattern = "^(PROFANITY|HATE_SPEECH|INSULT|"..
                                      "HARASSMENT_OR_ABUSE|SEXUAL|VIOLENCE_OR_THREAT)$"

-- Plugin configuration schema:
--   comprehend: AWS credentials/region and optional endpoint override
--   moderation_categories: per-category score thresholds in [0, 1]
--   moderation_threshold: overall toxicity threshold (default 0.5)
local schema = {
    type = "object",
    properties = {
        comprehend = {
            type = "object",
            properties = {
                access_key_id = { type = "string" },
                secret_access_key = { type = "string" },
                region = { type = "string" },
                endpoint = {
                    type = "string",
                    pattern = [[^https?://]]
                },
                ssl_verify = {
                    type = "boolean",
                    default = true
                }
            },
            required = { "access_key_id", "secret_access_key", "region", }
        },
        moderation_categories = {
            type = "object",
            patternProperties = {
                [moderation_categories_pattern] = {
                    type = "number",
                    minimum = 0,
                    maximum = 1
                }
            },
            additionalProperties = false
        },
        moderation_threshold = {
            type = "number",
            minimum = 0,
            maximum = 1,
            default = 0.5
        }
    },
    required = { "comprehend" },
}

local _M = {
    version = 0.1,
    priority = 1050,
    name = "ai-aws-content-moderation",
    schema = schema,
}
-- Validate a plugin configuration instance against the plugin schema.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end
-- rewrite phase: send the raw request body to AWS Comprehend's toxic-content
-- detection and reject the request (400) when any configured category score
-- or the overall toxicity exceeds its threshold.
function _M.rewrite(conf, ctx)
    -- Resolve any secret references embedded in the configuration.
    conf = fetch_secrets(conf, true, conf, "")
    if not conf then
        return HTTP_INTERNAL_SERVER_ERROR, "failed to retrieve secrets from conf"
    end

    -- The entire request body is moderated as a single text segment.
    local body, err = core.request.get_body()
    if not body then
        return HTTP_BAD_REQUEST, err
    end

    local comprehend = conf.comprehend

    -- The AWS SDK handle is cached at module level and reused across requests.
    if not aws_instance then
        aws_instance = aws()
    end
    local credentials = aws_instance:Credentials({
        accessKeyId = comprehend.access_key_id,
        secretAccessKey = comprehend.secret_access_key,
        sessionToken = comprehend.session_token,
    })

    local default_endpoint = "https://comprehend." .. comprehend.region .. ".amazonaws.com"
    local scheme, host, port = unpack(http:parse_uri(comprehend.endpoint or default_endpoint))
    local endpoint = scheme .. "://" .. host
    aws_instance.config.endpoint = endpoint
    aws_instance.config.ssl_verify = comprehend.ssl_verify

    -- NOTE(review): this declaration shadows the `comprehend` config table
    -- above with the Comprehend service client — easy to misread.
    local comprehend = aws_instance:Comprehend({
        credentials = credentials,
        endpoint = endpoint,
        region = comprehend.region,
        port = port,
    })

    local res, err = comprehend:detectToxicContent({
        LanguageCode = "en",
        TextSegments = {{
            Text = body
        }},
    })
    if not res then
        core.log.error("failed to send request to ", endpoint, ": ", err)
        return HTTP_INTERNAL_SERVER_ERROR, err
    end

    local results = res.body and res.body.ResultList
    if type(results) ~= "table" or core.table.isempty(results) then
        return HTTP_INTERNAL_SERVER_ERROR, "failed to get moderation results from response"
    end

    for _, result in ipairs(results) do
        -- Per-category thresholds take precedence; unknown labels are skipped.
        if conf.moderation_categories then
            for _, item in pairs(result.Labels) do
                if not conf.moderation_categories[item.Name] then
                    goto continue
                end
                if item.Score > conf.moderation_categories[item.Name] then
                    return HTTP_BAD_REQUEST, "request body exceeds " .. item.Name .. " threshold"
                end
                ::continue::
            end
        end

        -- Fall back to the overall toxicity score.
        if result.Toxicity > conf.moderation_threshold then
            return HTTP_BAD_REQUEST, "request body exceeds toxicity threshold"
        end
    end
end

return _M

View File

@@ -0,0 +1,24 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- AI driver for the AIML API, an OpenAI-compatible chat-completions service.
local openai_base = require("apisix.plugins.ai-drivers.openai-base")

return openai_base.new({
    host = "api.aimlapi.com",
    path = "/chat/completions",
    port = 443
})

View File

@@ -0,0 +1,24 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- AI driver for DeepSeek, an OpenAI-compatible chat-completions service.
local openai_base = require("apisix.plugins.ai-drivers.openai-base")

return openai_base.new({
    host = "api.deepseek.com",
    path = "/chat/completions",
    port = 443
})

View File

@@ -0,0 +1,255 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- Shared driver implementation for OpenAI-compatible chat-completion APIs.
-- Provider modules create instances via _M.new with per-provider defaults.
local _M = {}
local mt = {
    __index = _M
}

-- Only JSON request bodies are accepted by validate_request.
local CONTENT_TYPE_JSON = "application/json"
local core = require("apisix.core")
local http = require("resty.http")
local url = require("socket.url")
local ngx_re = require("ngx.re")
local ngx_print = ngx.print
local ngx_flush = ngx.flush
local pairs = pairs
local type = type
local ipairs = ipairs
local setmetatable = setmetatable
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_GATEWAY_TIMEOUT = ngx.HTTP_GATEWAY_TIMEOUT
-- Construct a driver instance bound to a provider's default host/port/path.
function _M.new(opts)
    local instance = {
        host = opts.host,
        port = opts.port,
        path = opts.path,
    }
    setmetatable(instance, mt)
    return instance
end
-- Reject non-JSON requests and decode the JSON request body.
-- Returns the decoded table, or nil plus an error message.
function _M.validate_request(ctx)
    local content_type = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
    if not core.string.has_prefix(content_type, CONTENT_TYPE_JSON) then
        local err = "unsupported content-type: " .. content_type
                    .. ", only application/json is supported"
        return nil, err
    end

    return core.request.get_json_request_body_table()
end
-- Map a transport-level error string to an HTTP status code:
-- timeouts become 504, everything else 500.
local function handle_error(err)
    local is_timeout = core.string.find(err, "timeout")
    return is_timeout and HTTP_GATEWAY_TIMEOUT or HTTP_INTERNAL_SERVER_ERROR
end
-- Relay the upstream AI response to the client and record token-usage stats
-- in ctx.ai_token_usage.
-- SSE (text/event-stream) responses are streamed through chunk by chunk and
-- parsed for a `usage` payload; other responses are read whole, relayed and
-- JSON-decoded. Returns an error status, or (status, body) for non-streaming
-- responses; streaming success returns nothing.
local function read_response(ctx, res)
    local body_reader = res.body_reader
    if not body_reader then
        core.log.warn("AI service sent no response body")
        return HTTP_INTERNAL_SERVER_ERROR
    end

    local content_type = res.headers["Content-Type"]
    core.response.set_header("Content-Type", content_type)

    if content_type and core.string.find(content_type, "text/event-stream") then
        while true do
            local chunk, err = body_reader() -- will read chunk by chunk
            if err then
                core.log.warn("failed to read response chunk: ", err)
                return handle_error(err)
            end
            if not chunk then
                return
            end

            -- Forward the chunk to the client before parsing it for usage.
            ngx_print(chunk)
            ngx_flush(true)

            local events, err = ngx_re.split(chunk, "\n")
            if err then
                core.log.warn("failed to split response chunk [", chunk, "] to events: ", err)
                goto CONTINUE
            end

            for _, event in ipairs(events) do
                -- NOTE(review): these gotos skip all remaining events of the
                -- current chunk, not just the current event — confirm intended.
                if not core.string.find(event, "data:") or core.string.find(event, "[DONE]") then
                    goto CONTINUE
                end

                local parts, err = ngx_re.split(event, ":", nil, nil, 2)
                if err then
                    core.log.warn("failed to split data event [", event, "] to parts: ", err)
                    goto CONTINUE
                end

                if #parts ~= 2 then
                    core.log.warn("malformed data event: ", event)
                    goto CONTINUE
                end

                local data, err = core.json.decode(parts[2])
                if err then
                    core.log.warn("failed to decode data event [", parts[2], "] to json: ", err)
                    goto CONTINUE
                end

                -- usage field is null for non-last events, null is parsed as userdata type
                if data and data.usage and type(data.usage) ~= "userdata" then
                    core.log.info("got token usage from ai service: ",
                                  core.json.delay_encode(data.usage))
                    ctx.ai_token_usage = {
                        prompt_tokens = data.usage.prompt_tokens or 0,
                        completion_tokens = data.usage.completion_tokens or 0,
                        total_tokens = data.usage.total_tokens or 0,
                    }
                end
            end

            ::CONTINUE::
        end
    end

    -- Non-streaming: read the entire body, then try to decode it for usage.
    local raw_res_body, err = res:read_body()
    if not raw_res_body then
        core.log.warn("failed to read response body: ", err)
        return handle_error(err)
    end

    local res_body, err = core.json.decode(raw_res_body)
    if err then
        -- Decode failure is non-fatal: the raw body is still returned.
        core.log.warn("invalid response body from ai service: ", raw_res_body, " err: ", err,
                      ", it will cause token usage not available")
    else
        core.log.info("got token usage from ai service: ", core.json.delay_encode(res_body.usage))
        ctx.ai_token_usage = {
            prompt_tokens = res_body.usage and res_body.usage.prompt_tokens or 0,
            completion_tokens = res_body.usage and res_body.usage.completion_tokens or 0,
            total_tokens = res_body.usage and res_body.usage.total_tokens or 0,
        }
    end

    return res.status, raw_res_body
end
-- Send the chat request to the LLM server and relay the response.
-- extra_opts may carry: endpoint (overrides the driver defaults from _M.new),
-- query_params, headers and model_options (merged into the request body).
-- Returns (status[, body]) via read_response, or nil, err on encode failure.
function _M.request(self, ctx, conf, request_table, extra_opts)
    local httpc, err = http.new()
    if not httpc then
        core.log.error("failed to create http client to send request to LLM server: ", err)
        return HTTP_INTERNAL_SERVER_ERROR
    end
    httpc:set_timeout(conf.timeout)

    -- An explicit endpoint takes precedence over the per-provider defaults.
    local endpoint = extra_opts.endpoint
    local parsed_url
    if endpoint then
        parsed_url = url.parse(endpoint)
    end

    local scheme = parsed_url and parsed_url.scheme or "https"
    local host = parsed_url and parsed_url.host or self.host
    local port = parsed_url and parsed_url.port
    if not port then
        if scheme == "https" then
            port = 443
        else
            port = 80
        end
    end

    local ok, err = httpc:connect({
        scheme = scheme,
        host = host,
        port = port,
        ssl_verify = conf.ssl_verify,
        ssl_server_name = parsed_url and parsed_url.host or self.host,
    })
    if not ok then
        core.log.warn("failed to connect to LLM server: ", err)
        return handle_error(err)
    end

    -- Merge query args embedded in the endpoint URL into the caller-supplied
    -- query params.
    -- NOTE(review): assumes extra_opts.query_params is non-nil whenever the
    -- endpoint carries a query string — confirm with callers.
    local query_params = extra_opts.query_params

    if type(parsed_url) == "table" and parsed_url.query and #parsed_url.query > 0 then
        local args_tab = core.string.decode_args(parsed_url.query)
        if type(args_tab) == "table" then
            core.table.merge(query_params, args_tab)
        end
    end

    local path = (parsed_url and parsed_url.path or self.path)

    local headers = extra_opts.headers
    headers["Content-Type"] = "application/json"
    local params = {
        method = "POST",
        headers = headers,
        ssl_verify = conf.ssl_verify,
        path = path,
        query = query_params
    }

    -- Fold provider-specific model options straight into the request body.
    if extra_opts.model_options then
        for opt, val in pairs(extra_opts.model_options) do
            request_table[opt] = val
        end
    end

    local req_json, err = core.json.encode(request_table)
    if not req_json then
        return nil, err
    end

    params.body = req_json

    local res, err = httpc:request(params)
    if not res then
        core.log.warn("failed to send request to LLM server: ", err)
        return handle_error(err)
    end

    local code, body = read_response(ctx, res)

    -- Return the connection to the keepalive pool when configured; failure
    -- to do so is logged but non-fatal.
    if conf.keepalive then
        local ok, err = httpc:set_keepalive(conf.keepalive_timeout, conf.keepalive_pool)
        if not ok then
            core.log.warn("failed to keepalive connection: ", err)
        end
    end

    return code, body
end

return _M

View File

@@ -0,0 +1,18 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- Generic OpenAI-compatible driver created without default host/port/path;
-- presumably the endpoint comes from plugin configuration — confirm with callers.
local openai_base = require("apisix.plugins.ai-drivers.openai-base")

return openai_base.new({})

View File

@@ -0,0 +1,24 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- AI driver for the official OpenAI chat-completions API.
local openai_base = require("apisix.plugins.ai-drivers.openai-base")

return openai_base.new({
    host = "api.openai.com",
    path = "/v1/chat/completions",
    port = 443
})

View File

@@ -0,0 +1,44 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local _M = {}

-- JSON schema for an OpenAI-style chat request body: a non-empty `messages`
-- array of {role, content} objects.
_M.chat_request_schema = {
    type = "object",
    properties = {
        messages = {
            type = "array",
            minItems = 1,
            items = {
                -- fixed: items previously lacked `type = "object"`, so
                -- non-object entries could slip past validation
                type = "object",
                properties = {
                    role = {
                        type = "string",
                        enum = {"system", "user", "assistant"}
                    },
                    content = {
                        type = "string",
                        -- fixed: minLength was the string "1"; JSON-schema
                        -- length constraints must be numbers
                        minLength = 1,
                    },
                },
                additionalProperties = false,
                required = {"role", "content"},
            },
        }
    },
    required = {"messages"}
}

return _M

View File

@@ -0,0 +1,117 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local ngx = ngx
local pairs = pairs
-- Shared read-only default for absent prepend/append lists.
-- NOTE(review): this table must never be mutated — it is shared by all routes.
local EMPTY = {}

-- Schema of a single chat message.
local prompt_schema = {
    properties = {
        role = {
            type = "string",
            enum = { "system", "user", "assistant" }
        },
        content = {
            type = "string",
            minLength = 1,
        }
    },
    required = { "role", "content" }
}

local prompts = {
    type = "array",
    items = prompt_schema
}

-- Plugin configuration: at least one of `prepend`/`append` message lists.
local schema = {
    type = "object",
    properties = {
        prepend = prompts,
        append = prompts,
    },
    anyOf = {
        { required = { "prepend" } },
        { required = { "append" } },
        { required = { "append", "prepend" } },
    },
}

local _M = {
    version = 0.1,
    priority = 1070,
    name = "ai-prompt-decorator",
    schema = schema,
}
-- Validate a plugin configuration instance against the plugin schema.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end
-- Read and JSON-decode the request body.
-- Returns the decoded table, or nil plus an error table suitable for
-- returning from a plugin phase.
local function get_request_body_table()
    local body, err = core.request.get_body()
    if not body then
        return nil, { message = "could not get body: " .. err }
    end

    local body_tab, err = core.json.decode(body)
    if not body_tab then
        -- fixed: error text previously read "could not get parse JSON ..."
        return nil, { message = "could not parse JSON request body: " .. err }
    end

    return body_tab
end
-- Build a fresh message list: conf.prepend .. body messages .. conf.append,
-- and install it as body_tab.messages.
-- fixed: the previous implementation used `conf.prepend or EMPTY` as the
-- target table, mutating the plugin configuration (and the shared EMPTY
-- table) on every request, so decorations accumulated across requests.
local function decorate(conf, body_tab)
    local new_messages = {}

    for _, message in pairs(conf.prepend or EMPTY) do
        core.table.insert_tail(new_messages, message)
    end

    for _, message in pairs(body_tab.messages) do
        core.table.insert_tail(new_messages, message)
    end

    for _, message in pairs(conf.append or EMPTY) do
        core.table.insert_tail(new_messages, message)
    end

    body_tab.messages = new_messages
end
-- rewrite phase: decode the body, splice configured messages around it,
-- and replace the request body with the re-encoded JSON.
function _M.rewrite(conf, ctx)
    local body_tab, err = get_request_body_table()
    if not body_tab then
        return 400, err
    end

    if not body_tab.messages then
        return 400, "messages missing from request body"
    end

    decorate(conf, body_tab) -- will decorate body_tab in place

    local new_jbody, err = core.json.encode(body_tab)
    if not new_jbody then
        return 500, { message = "failed to parse modified JSON request body: " .. err }
    end

    ngx.req.set_body_data(new_jbody)
end

return _M

View File

@@ -0,0 +1,153 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local ngx = ngx
local ipairs = ipairs
local table = table
local re_compile = require("resty.core.regex").re_match_compile
local re_find = ngx.re.find
local plugin_name = "ai-prompt-guard"

-- Plugin configuration:
--   match_all_roles: scan every role's messages, not only "user"
--   match_all_conversation_history: scan the whole history, not only the
--     last message
--   allow_patterns / deny_patterns: regex patterns, compiled with "jou" flags
local schema = {
    type = "object",
    properties = {
        match_all_roles = {
            type = "boolean",
            default = false,
        },
        match_all_conversation_history = {
            type = "boolean",
            default = false,
        },
        allow_patterns = {
            type = "array",
            items = {type = "string"},
            default = {},
        },
        deny_patterns = {
            type = "array",
            items = {type = "string"},
            default = {},
        },
    },
}

local _M = {
    version = 0.1,
    priority = 1072,
    name = plugin_name,
    schema = schema,
}
-- Validate the config and pre-compile every regex so that bad patterns are
-- rejected at configuration time instead of per request.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    if not ok then
        return false, err
    end

    local function compile_all(patterns, label)
        for _, pattern in ipairs(patterns) do
            if not re_compile(pattern, "jou") then
                return false, "invalid " .. label .. ": " .. pattern
            end
        end
        return true
    end

    local allow_ok, allow_err = compile_all(conf.allow_patterns, "allow_pattern")
    if not allow_ok then
        return false, allow_err
    end

    local deny_ok, deny_err = compile_all(conf.deny_patterns, "deny_pattern")
    if not deny_ok then
        return false, deny_err
    end

    return true
end
-- Select which messages to scan: the full conversation when
-- match_all_conversation_history is set, otherwise only the last message
-- (or an empty list when there are no messages).
local function get_content_to_check(conf, messages)
    if conf.match_all_conversation_history then
        return messages
    end

    local last_msg = messages[#messages]
    if last_msg ~= nil then
        return { last_msg }
    end

    return {}
end
-- access phase: match the request's chat messages against the configured
-- allow/deny patterns. Returns 400 on violation or malformed input.
function _M.access(conf, ctx)
    local body = core.request.get_body()
    if not body then
        core.log.error("Empty request body")
        return 400, {message = "Empty request body"}
    end

    local json_body, err = core.json.decode(body)
    if err then
        return 400, {message = err}
    end

    local messages = json_body.messages or {}
    messages = get_content_to_check(conf, messages)

    if not conf.match_all_roles then
        -- filter to only user messages
        local new_messages = {}
        for _, msg in ipairs(messages) do
            if msg.role == "user" then
                core.table.insert(new_messages, msg)
            end
        end
        messages = new_messages
    end

    -- NOTE(review): returning 200 from a plugin phase short-circuits the
    -- request in APISIX instead of letting it continue upstream — confirm a
    -- plain `return` is not intended here.
    if #messages == 0 then --nothing to check
        return 200
    end

    -- extract only messages
    local content = {}
    for _, msg in ipairs(messages) do
        if msg.content then
            core.table.insert(content, msg.content)
        end
    end
    -- All selected contents are joined and matched as a single string.
    local content_to_check = table.concat(content, " ")

    -- Allow patterns check
    if #conf.allow_patterns > 0 then
        local any_allowed = false
        for _, pattern in ipairs(conf.allow_patterns) do
            if re_find(content_to_check, pattern, "jou") then
                any_allowed = true
                break
            end
        end

        if not any_allowed then
            return 400, {message = "Request doesn't match allow patterns"}
        end
    end

    -- Deny patterns check
    for _, pattern in ipairs(conf.deny_patterns) do
        if re_find(content_to_check, pattern, "jou") then
            return 400, {message = "Request contains prohibited content"}
        end
    end
end

return _M

View File

@@ -0,0 +1,146 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local body_transformer = require("apisix.plugins.body-transformer")
local ipairs = ipairs
-- Schema of a single chat message within a template.
local prompt_schema = {
    properties = {
        role = {
            type = "string",
            enum = { "system", "user", "assistant" }
        },
        content = {
            type = "string",
            minLength = 1,
        }
    },
    required = { "role", "content" }
}

local prompts = {
    type = "array",
    minItems = 1,
    items = prompt_schema
}

-- Plugin configuration: a non-empty list of named templates, each holding an
-- optional model name and a list of messages.
local schema = {
    type = "object",
    properties = {
        templates = {
            type = "array",
            minItems = 1,
            items = {
                type = "object",
                properties = {
                    name = {
                        type = "string",
                        minLength = 1,
                    },
                    template = {
                        type = "object",
                        properties = {
                            model = {
                                type = "string",
                                minLength = 1,
                            },
                            messages = prompts
                        }
                    }
                },
                required = {"name", "template"}
            }
        },
    },
    required = {"templates"},
}

local _M = {
    version = 0.1,
    priority = 1071,
    name = "ai-prompt-template",
    schema = schema,
}

-- Caches: template name -> template table, and template table -> its JSON
-- encoding (keyed by table identity).
local templates_lrucache = core.lrucache.new({
    ttl = 300, count = 256
})

local templates_json_lrucache = core.lrucache.new({
    ttl = 300, count = 256
})
-- Validate a plugin configuration instance against the plugin schema.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end
-- Read and JSON-decode the request body.
-- Returns the decoded table, or nil plus an error table suitable for
-- returning from a plugin phase.
local function get_request_body_table()
    local body, err = core.request.get_body()
    if not body then
        return nil, { message = "could not get body: " .. err }
    end

    local body_tab, err = core.json.decode(body)
    if not body_tab then
        -- fixed: the error was previously passed as a stray array element
        -- ({ message = "...", err }) instead of being concatenated into the
        -- message string
        return nil, { message = "could not parse JSON request body: " .. err }
    end

    return body_tab
end
-- Linear scan of conf.templates for the entry named template_name.
-- Returns its template table, or nil when no entry matches.
local function find_template(conf, template_name)
    for i = 1, #conf.templates do
        local entry = conf.templates[i]
        if entry.name == template_name then
            return entry.template
        end
    end
    return nil
end
-- rewrite phase: look up the template named in the request body and hand it
-- to body-transformer, which renders it with the request's JSON fields.
function _M.rewrite(conf, ctx)
    local body_tab, err = get_request_body_table()
    if not body_tab then
        return 400, err
    end

    local template_name = body_tab.template_name
    if not template_name then
        return 400, { message = "template name is missing in request." }
    end

    -- Both lookups are cached: name -> template table, template -> JSON.
    local template = templates_lrucache(template_name, conf, find_template, conf, template_name)
    if not template then
        return 400, { message = "template: " .. template_name .. " not configured." }
    end

    local template_json = templates_json_lrucache(template, template, core.json.encode, template)
    core.log.info("sending template to body_transformer: ", template_json)
    return body_transformer.rewrite(
        {
            request = {
                template = template_json,
                input_format = "json"
            }
        },
        ctx
    )
end

return _M

View File

@@ -0,0 +1,227 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local schema = require("apisix.plugins.ai-proxy.schema")
local base = require("apisix.plugins.ai-proxy.base")
local plugin = require("apisix.plugin")
local require = require
local pcall = pcall
local ipairs = ipairs
local type = type
local priority_balancer = require("apisix.balancer.priority")
-- Lazily-loaded balancer picker modules, keyed by algorithm name.
local pickers = {}

-- Per-route cache of constructed server pickers.
local lrucache_server_picker = core.lrucache.new({
    ttl = 300, count = 256
})

local plugin_name = "ai-proxy-multi"

local _M = {
    version = 0.5,
    priority = 1041,
    name = plugin_name,
    schema = schema.ai_proxy_multi_schema,
}
-- Map a chash `hash_on` value to the schema used to validate the hash key.
-- Returns the schema; nil, nil for "consumer" (no key validation); or
-- nil, err for an unknown hash_on value.
local function get_chash_key_schema(hash_on)
    local key_schemas = {
        vars = core.schema.upstream_hash_vars_schema,
        header = core.schema.upstream_hash_header_schema,
        cookie = core.schema.upstream_hash_header_schema,
        vars_combinations = core.schema.upstream_hash_vars_combinations_schema,
    }

    local key_schema = key_schemas[hash_on]
    if key_schema then
        return key_schema
    end

    if hash_on == "consumer" then
        return nil, nil
    end

    return nil, "invalid hash_on type " .. hash_on
end
-- Validate the plugin configuration: base schema, loadable driver modules
-- for every instance, and consistent chash balancer settings.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema.ai_proxy_multi_schema, conf)
    if not ok then
        return false, err
    end

    -- Every configured instance must map to a loadable ai-driver module.
    for _, instance in ipairs(conf.instances) do
        local driver_ok, driver_err = pcall(require,
                                            "apisix.plugins.ai-drivers." .. instance.provider)
        if not driver_ok then
            -- fixed: log previously printed ", err" with no separator before
            -- the error value
            core.log.warn("fail to require ai provider: ", instance.provider,
                          ", err: ", driver_err)
            return false, "ai provider: " .. instance.provider .. " is not supported."
        end
    end

    local algo = core.table.try_read_attr(conf, "balancer", "algorithm")
    local hash_on = core.table.try_read_attr(conf, "balancer", "hash_on")
    local hash_key = core.table.try_read_attr(conf, "balancer", "key")

    if type(algo) == "string" and algo == "chash" then
        if not hash_on then
            return false, "must configure `hash_on` when balancer algorithm is chash"
        end

        -- "consumer" hashing derives its key implicitly; every other mode
        -- needs an explicit key.
        -- fixed: the message previously said "cookie" although the check is
        -- against "consumer"
        if hash_on ~= "consumer" and not hash_key then
            return false, "must configure `hash_key` when balancer `hash_on` is not set to consumer"
        end

        local key_schema, schema_err = get_chash_key_schema(hash_on)
        if schema_err then
            return false, "type is chash, err: " .. schema_err
        end

        if key_schema then
            local key_ok, key_err = core.schema.check(key_schema, hash_key)
            if not key_ok then
                return false, "invalid configuration: " .. key_err
            end
        end
    end

    return ok
end
-- Fold one instance into the grouped weight table:
-- new_instances[priority][name] = weight, with _priority_index recording
-- each distinct priority once.
local function transform_instances(new_instances, instance)
    local priority_index = new_instances._priority_index
    if not priority_index then
        priority_index = {}
        new_instances._priority_index = priority_index
    end

    local group = new_instances[instance.priority]
    if not group then
        group = {}
        new_instances[instance.priority] = group
        core.table.insert(priority_index, instance.priority)
    end

    group[instance.name] = instance.weight
end
-- Build a balancer picker over the configured instances.
-- When more than one priority level exists, the priority balancer wraps the
-- base algorithm; otherwise the base picker is used directly.
local function create_server_picker(conf, ups_tab)
    local picker = pickers[conf.balancer.algorithm] -- nil check
    if not picker then
        -- Balancer modules are loaded lazily and memoized in `pickers`.
        pickers[conf.balancer.algorithm] = require("apisix.balancer." .. conf.balancer.algorithm)
        picker = pickers[conf.balancer.algorithm]
    end

    local new_instances = {}
    for _, ins in ipairs(conf.instances) do
        transform_instances(new_instances, ins)
    end

    if #new_instances._priority_index > 1 then
        core.log.info("new instances: ", core.json.delay_encode(new_instances))
        return priority_balancer.new(new_instances, ups_tab, picker)
    end

    core.log.info("upstream nodes: ",
                  core.json.delay_encode(new_instances[new_instances._priority_index[1]]))
    return picker.new(new_instances[new_instances._priority_index[1]], ups_tab)
end
-- Find the full instance configuration for the given instance name.
-- Returns nil (implicitly) when no instance matches.
local function get_instance_conf(instances, name)
    for i = 1, #instances do
        local ins = instances[i]
        if ins.name == name then
            return ins
        end
    end
end
-- Pick a concrete AI instance via the (cached) server picker.
-- With fallback_strategy "instance_health_and_rate_limiting", unavailable
-- instances are skipped, re-picking at most once per configured instance.
-- Returns (instance_name, instance_conf) or (nil, nil, err).
local function pick_target(ctx, conf, ups_tab)
    local server_picker = ctx.server_picker
    if not server_picker then
        -- Picker is cached per route key and config version.
        server_picker = lrucache_server_picker(ctx.matched_route.key, plugin.conf_version(conf),
                                               create_server_picker, conf, ups_tab)
    end
    if not server_picker then
        return nil, nil, "failed to fetch server picker"
    end
    ctx.server_picker = server_picker

    local instance_name, err = server_picker.get(ctx)
    if err then
        return nil, nil, err
    end
    ctx.balancer_server = instance_name

    if conf.fallback_strategy == "instance_health_and_rate_limiting" then
        -- NOTE(review): required inline rather than at file top — presumably
        -- to avoid a load-time circular dependency; confirm.
        local ai_rate_limiting = require("apisix.plugins.ai-rate-limiting")
        for _ = 1, #conf.instances do
            if ai_rate_limiting.check_instance_status(nil, ctx, instance_name) then
                break
            end
            core.log.info("ai instance: ", instance_name,
                          " is not available, try to pick another one")
            server_picker.after_balance(ctx, true)
            instance_name, err = server_picker.get(ctx)
            if err then
                return nil, nil, err
            end
            ctx.balancer_server = instance_name
        end
    end

    local instance_conf = get_instance_conf(conf.instances, instance_name)
    return instance_name, instance_conf
end
-- Choose an AI instance: trivially when only one is configured, otherwise
-- via the balancer in pick_target.
local function pick_ai_instance(ctx, conf, ups_tab)
    if #conf.instances == 1 then
        local only = conf.instances[1]
        core.log.info("picked instance: ", only.name)
        return only.name, only
    end

    local instance_name, instance_conf, err = pick_target(ctx, conf, ups_tab)
    core.log.info("picked instance: ", instance_name)
    return instance_name, instance_conf, err
end
-- access phase: assemble chash upstream options when needed, pick an AI
-- instance, and stash it on ctx for the shared before_proxy to use.
function _M.access(conf, ctx)
    local ups_tab = {}
    local algo = core.table.try_read_attr(conf, "balancer", "algorithm")
    if algo == "chash" then
        local hash_on = core.table.try_read_attr(conf, "balancer", "hash_on")
        local hash_key = core.table.try_read_attr(conf, "balancer", "key")
        ups_tab["key"] = hash_key
        ups_tab["hash_on"] = hash_on
    end

    local name, ai_instance, err = pick_ai_instance(ctx, conf, ups_tab)
    if err then
        return 503, err
    end
    ctx.picked_ai_instance_name = name
    ctx.picked_ai_instance = ai_instance
    -- The actual proxying is done by the shared base module, not nginx.
    ctx.bypass_nginx_upstream = true
end

_M.before_proxy = base.before_proxy

return _M

View File

@@ -0,0 +1,57 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local schema = require("apisix.plugins.ai-proxy.schema")
local base = require("apisix.plugins.ai-proxy.base")
local require = require
local pcall = pcall
local plugin_name = "ai-proxy"
-- Plugin registration table: name, version, priority and config schema.
local _M = {
    version = 0.5,
    priority = 1040,
    name = plugin_name,
    schema = schema.ai_proxy_schema,
}
-- Validate the plugin configuration and make sure the configured provider
-- maps to a loadable ai-driver module.
-- @return boolean ok, string|nil err
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema.ai_proxy_schema, conf)
    if not ok then
        return false, err
    end

    -- pcall returns (status, module-or-error): the original bound the
    -- status boolean to `ai_driver` and the module/error to `err`, which
    -- worked but was misleading and logged with a malformed ", err"
    -- fragment.  Named consistently with ai-request-rewrite's loader.
    local loaded, driver_or_err = pcall(require, "apisix.plugins.ai-drivers." .. conf.provider)
    if not loaded then
        core.log.warn("fail to require ai provider: ", conf.provider,
                      ", err: ", driver_or_err)
        return false, "ai provider: " .. conf.provider .. " is not supported."
    end

    return ok
end
-- Access phase: record the plugin conf as the picked "instance" and mark
-- the request so nginx's upstream handling is bypassed.
function _M.access(conf, ctx)
    ctx.picked_ai_instance = conf
    ctx.picked_ai_instance_name = "ai-proxy"
    ctx.bypass_nginx_upstream = true
end
-- Delegate the proxy step to the shared ai-proxy base implementation.
_M.before_proxy = base.before_proxy

return _M

View File

@@ -0,0 +1,50 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local require = require
local bad_request = ngx.HTTP_BAD_REQUEST
local _M = {}

-- Shared proxy step for the ai-proxy family of plugins: validate the
-- client request via the picked instance's driver, then hand the request
-- to that driver.
-- @return whatever ai_driver:request returns, or bad_request + err when
--         the incoming body fails driver validation
function _M.before_proxy(conf, ctx)
    local instance = ctx.picked_ai_instance
    local driver = require("apisix.plugins.ai-drivers." .. instance.provider)

    local body, err = driver.validate_request(ctx)
    if not body then
        return bad_request, err
    end

    -- for streaming responses, ask the upstream to append usage stats
    if body.stream then
        body.stream_options = {
            include_usage = true
        }
    end

    local opts = {
        endpoint = core.table.try_read_attr(instance, "override", "endpoint"),
        query_params = instance.auth.query or {},
        headers = instance.auth.header or {},
        model_options = instance.options,
    }

    return driver:request(ctx, conf, body, opts)
end

return _M

View File

@@ -0,0 +1,219 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local _M = {}

-- One auth table: identifier-like keys mapping to string values
-- (e.g. { Authorization = "Bearer ..." }).
local auth_item_schema = {
    type = "object",
    patternProperties = {
        ["^[a-zA-Z0-9._-]+$"] = {
            type = "string"
        }
    }
}

-- Credentials may be supplied as request headers and/or query params.
-- NOTE: `header` and `query` are fixed field names, so they belong under
-- `properties`.  The original placed them under `patternProperties`,
-- which treats them as unanchored regexes and would accept any field
-- name containing "header"/"query" (e.g. "xheader").  This also matches
-- the auth_schema used by the ai-request-rewrite plugin.
local auth_schema = {
    type = "object",
    properties = {
        header = auth_item_schema,
        query = auth_item_schema,
    },
    additionalProperties = false,
}

-- Free-form model settings; only `model` itself is described explicitly.
local model_options_schema = {
    description = "Key/value settings for the model",
    type = "object",
    properties = {
        model = {
            type = "string",
            description = "Model to execute.",
        },
    },
    additionalProperties = true,
}
-- List of AI service instances for ai-proxy-multi; each entry names a
-- provider, its credentials and its load-balancing weight/priority.
local ai_instance_schema = {
    type = "array",
    minItems = 1,
    items = {
        type = "object",
        properties = {
            name = {
                type = "string",
                minLength = 1,
                maxLength = 100,
                description = "Name of the AI service instance.",
            },
            provider = {
                type = "string",
                description = "Type of the AI service instance.",
                enum = {
                    "openai",
                    "deepseek",
                    "aimlapi",
                    "openai-compatible",
                }, -- add more providers later
            },
            priority = {
                type = "integer",
                description = "Priority of the provider for load balancing",
                default = 0,
            },
            weight = {
                type = "integer",
                minimum = 0,
            },
            auth = auth_schema,
            options = model_options_schema,
            override = {
                type = "object",
                properties = {
                    endpoint = {
                        type = "string",
                        description = "To be specified to override the endpoint of the AI Instance",
                    },
                },
            },
        },
        required = {"name", "provider", "auth", "weight"}
    },
}

-- Config schema for the single-instance ai-proxy plugin.
_M.ai_proxy_schema = {
    type = "object",
    properties = {
        provider = {
            type = "string",
            description = "Type of the AI service instance.",
            enum = {
                "openai",
                "deepseek",
                "aimlapi",
                "openai-compatible",
            }, -- add more providers later
        },
        auth = auth_schema,
        options = model_options_schema,
        timeout = {
            type = "integer",
            minimum = 1,
            default = 30000,
            description = "timeout in milliseconds",
        },
        keepalive = {type = "boolean", default = true},
        keepalive_timeout = {
            type = "integer",
            minimum = 1000,
            default = 60000,
            description = "keepalive timeout in milliseconds",
        },
        keepalive_pool = {type = "integer", minimum = 1, default = 30},
        ssl_verify = {type = "boolean", default = true },
        override = {
            type = "object",
            properties = {
                endpoint = {
                    type = "string",
                    description = "To be specified to override the endpoint of the AI Instance",
                },
            },
        },
    },
    required = {"provider", "auth"}
}

-- Config schema for ai-proxy-multi: load-balances across `instances`.
_M.ai_proxy_multi_schema = {
    type = "object",
    properties = {
        balancer = {
            type = "object",
            properties = {
                algorithm = {
                    type = "string",
                    enum = { "chash", "roundrobin" },
                },
                hash_on = {
                    type = "string",
                    default = "vars",
                    enum = {
                        "vars",
                        "header",
                        "cookie",
                        "consumer",
                        "vars_combinations",
                    },
                },
                key = {
                    description = "the key of chash for dynamic load balancing",
                    type = "string",
                },
            },
            default = { algorithm = "roundrobin" }
        },
        instances = ai_instance_schema,
        fallback_strategy = {
            type = "string",
            enum = { "instance_health_and_rate_limiting" },
            default = "instance_health_and_rate_limiting",
        },
        timeout = {
            type = "integer",
            minimum = 1,
            default = 30000,
            description = "timeout in milliseconds",
        },
        keepalive = {type = "boolean", default = true},
        keepalive_timeout = {
            type = "integer",
            minimum = 1000,
            default = 60000,
            description = "keepalive timeout in milliseconds",
        },
        keepalive_pool = {type = "integer", minimum = 1, default = 30},
        ssl_verify = {type = "boolean", default = true },
    },
    required = {"instances"}
}
-- Schema for the incoming chat-completion request body.
-- NOTE(review): `items` carries no `type = "object"`; validators may not
-- enforce object-ness of each message — confirm against the jsonschema lib.
_M.chat_request_schema = {
    type = "object",
    properties = {
        messages = {
            type = "array",
            minItems = 1,
            items = {
                properties = {
                    role = {
                        type = "string",
                        enum = {"system", "user", "assistant"}
                    },
                    content = {
                        type = "string",
                        -- was the string "1": JSON Schema requires
                        -- minLength to be a non-negative integer
                        minLength = 1,
                    },
                },
                additionalProperties = false,
                required = {"role", "content"},
            },
        }
    },
    required = {"messages"}
}

return _M

View File

@@ -0,0 +1,156 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local next = next
local require = require
local ngx_req = ngx.req
local http = require("resty.http")
local core = require("apisix.core")
local azure_openai_embeddings = require("apisix.plugins.ai-rag.embeddings.azure_openai").schema
local azure_ai_search_schema = require("apisix.plugins.ai-rag.vector-search.azure_ai_search").schema
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_BAD_REQUEST = ngx.HTTP_BAD_REQUEST
-- Plugin configuration: exactly one embeddings provider and exactly one
-- vector-search provider must be configured.
-- (A stray `type = "object"` entry inside `properties` was removed: it
-- declared a bogus property literally named "type" whose "schema" was the
-- plain string "object".)
local schema = {
    type = "object",
    properties = {
        embeddings_provider = {
            type = "object",
            properties = {
                azure_openai = azure_openai_embeddings
            },
            -- ensure only one provider can be configured while implementing support for
            -- other providers
            required = { "azure_openai" },
            maxProperties = 1,
        },
        vector_search_provider = {
            type = "object",
            properties = {
                azure_ai_search = azure_ai_search_schema
            },
            -- ensure only one provider can be configured while implementing support for
            -- other providers
            required = { "azure_ai_search" },
            maxProperties = 1
        },
    },
    required = { "embeddings_provider", "vector_search_provider" }
}
-- Schema for the incoming request body.  The empty `vector_search` and
-- `embeddings` sub-schemas are filled in per request from the selected
-- drivers' request_schema fields (see _M.access).
local request_schema = {
    type = "object",
    properties = {
        ai_rag = {
            type = "object",
            properties = {
                vector_search = {},
                embeddings = {},
            },
            required = { "vector_search", "embeddings" }
        }
    }
}
-- Plugin registration table.
local _M = {
    version = 0.1,
    priority = 1060,
    name = "ai-rag",
    schema = schema,
}

-- Validate the plugin configuration against the schema.
function _M.check_schema(conf)
    return core.schema.check(schema, conf)
end
-- Access phase: embed the query, run a vector search, and append the
-- retrieved context to the chat messages before the request reaches the
-- LLM.  Returns an error status code on failure; returns nothing (request
-- continues) on success.
function _M.access(conf, ctx)
    local httpc = http.new()
    local body_tab, err = core.request.get_json_request_body_table()
    if not body_tab then
        return HTTP_BAD_REQUEST, err
    end
    if not body_tab["ai_rag"] then
        -- fixed message: the required body field is "ai_rag"
        -- (the original message said "ai-rag", the plugin name)
        core.log.error("request body must have \"ai_rag\" field")
        return HTTP_BAD_REQUEST
    end

    -- the conf schema enforces exactly one provider of each kind,
    -- so next() yields the configured one
    local embeddings_provider = next(conf.embeddings_provider)
    local embeddings_provider_conf = conf.embeddings_provider[embeddings_provider]
    local embeddings_driver = require("apisix.plugins.ai-rag.embeddings." .. embeddings_provider)

    local vector_search_provider = next(conf.vector_search_provider)
    local vector_search_provider_conf = conf.vector_search_provider[vector_search_provider]
    local vector_search_driver = require("apisix.plugins.ai-rag.vector-search." ..
                                         vector_search_provider)

    -- plug the drivers' request schemas into the shared request_schema,
    -- then validate the body against the completed schema
    local vs_req_schema = vector_search_driver.request_schema
    local emb_req_schema = embeddings_driver.request_schema
    request_schema.properties.ai_rag.properties.vector_search = vs_req_schema
    request_schema.properties.ai_rag.properties.embeddings = emb_req_schema

    local ok, err = core.schema.check(request_schema, body_tab)
    if not ok then
        core.log.error("request body fails schema check: ", err)
        return HTTP_BAD_REQUEST
    end

    local embeddings, status, err = embeddings_driver.get_embeddings(embeddings_provider_conf,
                                                    body_tab["ai_rag"].embeddings, httpc)
    if not embeddings then
        core.log.error("could not get embeddings: ", err)
        return status, err
    end

    local search_body = body_tab["ai_rag"].vector_search
    search_body.embeddings = embeddings
    local res, status, err = vector_search_driver.search(vector_search_provider_conf,
                                                         search_body, httpc)
    if not res then
        core.log.error("could not get vector_search result: ", err)
        return status, err
    end

    -- remove ai_rag from request body because their purpose is served
    -- also, these values will cause failure when proxying requests to LLM.
    body_tab["ai_rag"] = nil

    if not body_tab.messages then
        body_tab.messages = {}
    end

    -- append the retrieved context as an extra user message
    local augment = {
        role = "user",
        content = res
    }
    core.table.insert_tail(body_tab.messages, augment)

    local req_body_json, err = core.json.encode(body_tab)
    if not req_body_json then
        return HTTP_INTERNAL_SERVER_ERROR, err
    end
    ngx_req.set_body_data(req_body_json)
end

return _M

View File

@@ -0,0 +1,88 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_OK = ngx.HTTP_OK
local type = type
local _M = {}

-- Provider configuration: the full Azure OpenAI embeddings endpoint URL
-- and its api-key.
_M.schema = {
    type = "object",
    properties = {
        endpoint = {
            type = "string",
        },
        api_key = {
            type = "string",
        },
    },
    required = { "endpoint", "api_key" }
}
-- Request an embedding vector from an Azure OpenAI deployment.
-- @param conf  provider conf: { endpoint = ..., api_key = ... }
-- @param body  request payload (e.g. { input = "..." }); JSON-encoded as-is
-- @param httpc resty.http client supplied by the caller
-- @return the embedding table (data[1].embedding) on success,
--         or nil, http-status, err on failure
function _M.get_embeddings(conf, body, httpc)
    -- `req_body` is a JSON string (the original misleadingly called it
    -- `body_tab`)
    local req_body, err = core.json.encode(body)
    if not req_body then
        return nil, HTTP_INTERNAL_SERVER_ERROR, err
    end

    local res, req_err = httpc:request_uri(conf.endpoint, {
        method = "POST",
        headers = {
            ["Content-Type"] = "application/json",
            ["api-key"] = conf.api_key,
        },
        body = req_body
    })
    if not res or not res.body then
        return nil, HTTP_INTERNAL_SERVER_ERROR, req_err
    end

    if res.status ~= HTTP_OK then
        return nil, res.status, res.body
    end

    local res_tab, decode_err = core.json.decode(res.body)
    if not res_tab then
        return nil, HTTP_INTERNAL_SERVER_ERROR, decode_err
    end

    if type(res_tab.data) ~= "table" or core.table.isempty(res_tab.data) then
        return nil, HTTP_INTERNAL_SERVER_ERROR, res.body
    end

    -- The original re-encoded data[1].embedding to JSON and discarded the
    -- result; callers consume the raw decoded table, so that dead
    -- round-trip has been dropped.
    return res_tab.data[1].embedding
end
-- Schema for the per-request embeddings input (body.ai_rag.embeddings).
_M.request_schema = {
    type = "object",
    properties = {
        input = {
            type = "string"
        }
    },
    required = { "input" }
}

return _M

View File

@@ -0,0 +1,83 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_OK = ngx.HTTP_OK
local _M = {}

-- Provider configuration: the full Azure AI Search endpoint URL and its
-- api-key.
_M.schema = {
    type = "object",
    properties = {
        endpoint = {
            type = "string",
        },
        api_key = {
            type = "string",
        },
    },
    required = {"endpoint", "api_key"}
}
-- Run a vector similarity query against Azure AI Search.
-- @param conf        provider conf: { endpoint = ..., api_key = ... }
-- @param search_body { embeddings = <vector>, fields = <field names> }
-- @param httpc       resty.http client supplied by the caller
-- @return raw response body string on success,
--         or nil, http-status, err on failure
function _M.search(conf, search_body, httpc)
    local payload, err = core.json.encode({
        vectorQueries = {
            {
                kind = "vector",
                vector = search_body.embeddings,
                fields = search_body.fields
            }
        }
    })
    if not payload then
        return nil, HTTP_INTERNAL_SERVER_ERROR, err
    end

    local res, req_err = httpc:request_uri(conf.endpoint, {
        method = "POST",
        headers = {
            ["Content-Type"] = "application/json",
            ["api-key"] = conf.api_key,
        },
        body = payload
    })
    if not res or not res.body then
        return nil, HTTP_INTERNAL_SERVER_ERROR, req_err
    end
    if res.status ~= HTTP_OK then
        return nil, res.status, res.body
    end

    return res.body
end
-- Schema for the per-request vector-search input (body.ai_rag.vector_search).
_M.request_schema = {
    type = "object",
    properties = {
        fields = {
            type = "string"
        }
    },
    required = { "fields" }
}

return _M

View File

@@ -0,0 +1,234 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local require = require
local setmetatable = setmetatable
local ipairs = ipairs
local type = type
local core = require("apisix.core")
local limit_count = require("apisix.plugins.limit-count.init")
local plugin_name = "ai-rate-limiting"

-- Per-instance limit entry: token budget `limit` per `time_window`
-- seconds for the named AI instance.
local instance_limit_schema = {
    type = "object",
    properties = {
        name = {type = "string"},
        limit = {type = "integer", minimum = 1},
        time_window = {type = "integer", minimum = 1}
    },
    required = {"name", "limit", "time_window"}
}

-- Plugin conf: a plugin-wide limit/time_window pair, a list of
-- per-instance limits, or both.  The dependencies/anyOf clauses enforce
-- that limit and time_window always come together and that at least one
-- of the two forms is present.
local schema = {
    type = "object",
    properties = {
        limit = {type = "integer", exclusiveMinimum = 0},
        time_window = {type = "integer", exclusiveMinimum = 0},
        show_limit_quota_header = {type = "boolean", default = true},
        limit_strategy = {
            type = "string",
            enum = {"total_tokens", "prompt_tokens", "completion_tokens"},
            default = "total_tokens",
            description = "The strategy to limit the tokens"
        },
        instances = {
            type = "array",
            items = instance_limit_schema,
            minItems = 1,
        },
        rejected_code = {
            type = "integer", minimum = 200, maximum = 599, default = 503
        },
        rejected_msg = {
            type = "string", minLength = 1
        },
    },
    dependencies = {
        limit = {"time_window"},
        time_window = {"limit"}
    },
    anyOf = {
        {
            required = {"limit", "time_window"}
        },
        {
            required = {"instances"}
        }
    }
}

local _M = {
    version = 0.1,
    priority = 1030,
    name = plugin_name,
    schema = schema
}

-- Cache of per-conf limit tables built by fetch_limit_conf_kvs (keyed by
-- the conf table itself).
local limit_conf_cache = core.lrucache.new({
    ttl = 300, count = 512
})
-- Validate the plugin configuration against the schema.
function _M.check_schema(conf)
    return core.schema.check(schema, conf)
end
-- Build a conf table for the limit-count plugin.
-- With instance_conf: the instance's own limit, keyed by its name.
-- Without: the plugin-wide limit, keyed by a shared constant so all
-- instances draw from one counter; instance_name only feeds the
-- X-AI-RateLimit-* response header names.
local function transform_limit_conf(plugin_conf, instance_conf, instance_name)
    local key = plugin_name .. "#global"
    local limit = plugin_conf.limit
    local time_window = plugin_conf.time_window
    local name = instance_name or ""
    if instance_conf then
        name = instance_conf.name
        key = instance_conf.name
        limit = instance_conf.limit
        time_window = instance_conf.time_window
    end
    return {
        _vid = key,

        key = key,
        count = limit,
        time_window = time_window,
        rejected_code = plugin_conf.rejected_code,
        rejected_msg = plugin_conf.rejected_msg,
        show_limit_quota_header = plugin_conf.show_limit_quota_header,
        -- limit-count need these fields
        policy = "local",
        key_type = "constant",
        allow_degradation = false,
        sync_interval = -1,

        limit_header = "X-AI-RateLimit-Limit-" .. name,
        remaining_header = "X-AI-RateLimit-Remaining-" .. name,
        reset_header = "X-AI-RateLimit-Reset-" .. name,
    }
end
-- Build a map of instance name -> limit-count conf.
-- Names listed in conf.instances get dedicated confs eagerly; any other
-- name is materialised lazily via __index from the plugin-wide limit
-- (and memoised into the table), or resolves to nil when no global limit
-- is configured.
local function fetch_limit_conf_kvs(conf)
    local mt = {
        __index = function(t, k)
            if not conf.limit then
                return nil
            end

            local limit_conf = transform_limit_conf(conf, nil, k)
            t[k] = limit_conf
            return limit_conf
        end
    }
    local limit_conf_kvs = setmetatable({}, mt)
    local conf_instances = conf.instances or {}
    for _, limit_conf in ipairs(conf_instances) do
        limit_conf_kvs[limit_conf.name] = transform_limit_conf(conf, limit_conf)
    end
    return limit_conf_kvs
end
-- Access phase: consume 1 unit from the picked instance's counter.
-- ctx.ai_rate_limiting records whether this request was rejected so the
-- log phase can skip token accounting for it.
function _M.access(conf, ctx)
    local ai_instance_name = ctx.picked_ai_instance_name
    if not ai_instance_name then
        return
    end

    local limit_conf_kvs = limit_conf_cache(conf, nil, fetch_limit_conf_kvs, conf)
    local limit_conf = limit_conf_kvs[ai_instance_name]
    if not limit_conf then
        return
    end

    local code, msg = limit_count.rate_limit(limit_conf, ctx, plugin_name, 1, true)
    -- code is non-nil only when the request was rejected
    ctx.ai_rate_limiting = code and true or false

    return code, msg
end
-- Report whether an instance still has rate-limit budget.  Called by
-- other plugins (e.g. ai-proxy-multi's fallback) which may pass
-- conf == nil; in that case this plugin's conf is looked up from
-- ctx.plugins, a flat {plugin, conf, plugin, conf, ...} list.
-- NOTE(review): the trailing `true` passed to rate_limit appears to make
-- this a non-consuming (dry-run) check — confirm against
-- limit-count.rate_limit's signature.
-- @return true when usable, false when rate limited, nil+err on bad input
function _M.check_instance_status(conf, ctx, instance_name)
    if conf == nil then
        local plugins = ctx.plugins
        for i = 1, #plugins, 2 do
            if plugins[i]["name"] == plugin_name then
                conf = plugins[i + 1]
            end
        end
    end

    if not conf then
        -- ai-rate-limiting not enabled on this route: nothing to restrict
        return true
    end

    instance_name = instance_name or ctx.picked_ai_instance_name
    if not instance_name then
        return nil, "missing instance_name"
    end

    if type(instance_name) ~= "string" then
        return nil, "invalid instance_name"
    end

    local limit_conf_kvs = limit_conf_cache(conf, nil, fetch_limit_conf_kvs, conf)
    local limit_conf = limit_conf_kvs[instance_name]
    if not limit_conf then
        return true
    end

    local code, _ = limit_count.rate_limit(limit_conf, ctx, plugin_name, 1, true)
    if code then
        core.log.info("rate limit for instance: ", instance_name, " code: ", code)
        return false
    end

    return true
end
-- Read the token count recorded for this request, selected by the
-- configured limit strategy (total/prompt/completion tokens).
-- Returns nil when no usage was recorded.
local function get_token_usage(conf, ctx)
    local usage = ctx.ai_token_usage
    if not usage then
        return
    end
    local strategy = conf.limit_strategy
    return usage[strategy]
end
-- Log phase: charge the tokens actually consumed by the LLM call against
-- the instance's counter.  Skipped when the request was rejected in the
-- access phase (ctx.ai_rate_limiting).
function _M.log(conf, ctx)
    local instance_name = ctx.picked_ai_instance_name
    if not instance_name then
        return
    end

    if ctx.ai_rate_limiting then
        return
    end

    local used_tokens = get_token_usage(conf, ctx)
    if not used_tokens then
        core.log.error("failed to get token usage for llm service")
        return
    end

    core.log.info("instance name: ", instance_name, " used tokens: ", used_tokens)

    local limit_conf_kvs = limit_conf_cache(conf, nil, fetch_limit_conf_kvs, conf)
    local limit_conf = limit_conf_kvs[instance_name]
    if limit_conf then
        limit_count.rate_limit(limit_conf, ctx, plugin_name, used_tokens)
    end
end

return _M

View File

@@ -0,0 +1,231 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local require = require
local pcall = pcall
local ngx = ngx
local req_set_body_data = ngx.req.set_body_data
local HTTP_BAD_REQUEST = ngx.HTTP_BAD_REQUEST
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local plugin_name = "ai-request-rewrite"

-- One auth table: identifier-like keys mapping to string values.
local auth_item_schema = {
    type = "object",
    patternProperties = {
        ["^[a-zA-Z0-9._-]+$"] = {
            type = "string"
        }
    }
}

-- Credentials may be supplied as request headers and/or query params.
local auth_schema = {
    type = "object",
    properties = {
        header = auth_item_schema,
        query = auth_item_schema
    },
    additionalProperties = false
}
-- Free-form model settings; only `model` itself is described explicitly.
local model_options_schema = {
    description = "Key/value settings for the model",
    type = "object",
    properties = {
        model = {
            type = "string",
            -- fixed "deekseek" typo in the description
            description = "Model to execute. Examples: \"gpt-3.5-turbo\" for openai, " ..
                "\"deepseek-chat\" for deepseek, or \"qwen-turbo\" for openai-compatible services"
        }
    },
    additionalProperties = true
}
-- Plugin configuration schema.
local schema = {
    type = "object",
    properties = {
        prompt = {
            type = "string",
            description = "The prompt to rewrite client request."
        },
        provider = {
            type = "string",
            description = "Name of the AI service provider.",
            enum = {
                "openai",
                "openai-compatible",
                "deepseek",
                "aimlapi"
            } -- add more providers later
        },
        auth = auth_schema,
        options = model_options_schema,
        timeout = {
            type = "integer",
            minimum = 1,
            maximum = 60000,
            default = 30000,
            description = "Total timeout in milliseconds for requests to LLM service, " ..
                "including connect, send, and read timeouts."
        },
        keepalive = {
            type = "boolean",
            default = true
        },
        keepalive_pool = {
            type = "integer",
            minimum = 1,
            default = 30
        },
        ssl_verify = {
            type = "boolean",
            default = true
        },
        override = {
            type = "object",
            properties = {
                endpoint = {
                    type = "string",
                    description = "To be specified to override " ..
                        "the endpoint of the AI service provider."
                }
            }
        }
    },
    required = {"prompt", "provider", "auth"}
}

-- Plugin registration table.
local _M = {
    version = 0.1,
    priority = 1073,
    name = plugin_name,
    schema = schema
}
-- Send the rewrite conversation to the configured LLM provider.
-- @return res, resp_body on success; nil, nil, err on failure.
-- NOTE(review): the driver request() arguments here (conf, request_table,
-- extra_opts) differ from the ai-proxy base module's (ctx, conf, body,
-- opts) — presumably the drivers accept both shapes; confirm against
-- apisix.plugins.ai-drivers.
local function request_to_llm(conf, request_table, ctx)
    local ok, ai_driver = pcall(require, "apisix.plugins.ai-drivers." .. conf.provider)
    if not ok then
        return nil, nil, "failed to load ai-driver: " .. conf.provider
    end

    local extra_opts = {
        endpoint = core.table.try_read_attr(conf, "override", "endpoint"),
        query_params = conf.auth.query or {},
        headers = (conf.auth.header or {}),
        model_options = conf.options
    }

    local res, err, httpc = ai_driver:request(conf, request_table, extra_opts)
    if err then
        return nil, nil, err
    end

    -- drain the whole body before releasing the connection
    local resp_body, err = res:read_body()
    httpc:close()
    if err then
        return nil, nil, err
    end

    return res, resp_body
end
-- Extract the assistant message content from an OpenAI-style response.
-- @return content string on success, or nil, err
local function parse_llm_response(res_body)
    local response_table, err = core.json.decode(res_body)

    if err then
        -- the original concatenated two fragments producing
        -- "failed to decode llm response , err: ..."; collapsed into one
        -- well-formed message
        return nil, "failed to decode llm response, err: " .. err
    end

    if not response_table.choices or not response_table.choices[1] then
        return nil, "'choices' not in llm response"
    end

    local message = response_table.choices[1].message
    if not message then
        return nil, "'message' not in llm response choices"
    end

    return message.content
end
-- Validate the plugin configuration.
-- The openai-compatible provider has no built-in endpoint, so it must be
-- paired with override.endpoint.
function _M.check_schema(conf)
    if conf.provider == "openai-compatible" then
        local endpoint = conf.override and conf.override.endpoint
        if not endpoint then
            return false, "override.endpoint is required for openai-compatible provider"
        end
    end
    return core.schema.check(schema, conf)
end
-- Access phase: ask the LLM to rewrite the client request body according
-- to conf.prompt, then substitute the rewritten content as the new body.
-- A missing body lets the request pass through untouched; LLM or parse
-- failures return 500.
function _M.access(conf, ctx)
    local client_request_body, err = core.request.get_body()
    if err then
        core.log.warn("failed to get request body: ", err)
        return HTTP_BAD_REQUEST
    end

    -- nothing to rewrite: continue without calling the LLM
    if not client_request_body then
        core.log.warn("missing request body")
        return
    end

    -- Prepare request for LLM service
    local ai_request_table = {
        messages = {
            {
                role = "system",
                content = conf.prompt
            },
            {
                role = "user",
                content = client_request_body
            }
        },
        stream = false
    }

    -- Send request to LLM service
    local res, resp_body, err = request_to_llm(conf, ai_request_table, ctx)
    if err then
        core.log.error("failed to request to LLM service: ", err)
        return HTTP_INTERNAL_SERVER_ERROR
    end

    -- Handle LLM response
    if res.status > 299 then
        core.log.error("LLM service returned error status: ", res.status)
        return HTTP_INTERNAL_SERVER_ERROR
    end

    -- Parse LLM response
    local llm_response, err = parse_llm_response(resp_body)
    if err then
        core.log.error("failed to parse LLM response: ", err)
        return HTTP_INTERNAL_SERVER_ERROR
    end

    req_set_body_data(llm_response)
end

return _M

View File

@@ -0,0 +1,324 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local require = require
local apisix = require("apisix")
local core = require("apisix.core")
local router = require("apisix.router")
local get_global_rules = require("apisix.global_rules").global_rules
local event = require("apisix.core.event")
local balancer = require("ngx.balancer")
local ngx = ngx
local is_http = ngx.config.subsystem == "http"
local enable_keepalive = balancer.enable_keepalive and is_http
local is_apisix_or, response = pcall(require, "resty.apisix.response")
local ipairs = ipairs
local pcall = pcall
local loadstring = loadstring
local type = type
local pairs = pairs
-- Compiled route-cache-key function and its template renderer; both are
-- built lazily in gen_get_cache_key_func.
local get_cache_key_func
local get_cache_key_func_def_render

-- Template for the cache-key function: method/host segments are only
-- compiled in when some route actually matches on them.
local get_cache_key_func_def = [[
return function(ctx)
    local var = ctx.var
    return var.uri
    {% if route_flags["methods"] then %}
        .. "#" .. var.method
    {% end %}
    {% if route_flags["host"] then %}
        .. "#" .. var.host
    {% end %}
end
]]

-- Route-matching result cache; rebuilt on every router change.
local route_lrucache

local schema = {}

local plugin_name = "ai"

local _M = {
    version = 0.1,
    priority = 22900,
    name = plugin_name,
    schema = schema,
    scope = "global",
}

-- Originals saved before this plugin patches router matching, upstream
-- handling and the balancer phase (restored in _M.destroy).
local orig_router_http_matching
local orig_handle_upstream
local orig_http_balancer_phase

local default_keepalive_pool = {}
-- Run the original router matching once, then deep-copy the resulting
-- api_ctx so it can be cached; the upstream "parent" reference is kept
-- shallow (shared) rather than copied.
local function create_router_matching_cache(api_ctx)
    orig_router_http_matching(api_ctx)
    return core.table.deepcopy(api_ctx, {
        shallows = { "self.matched_route.value.upstream.parent" }
    })
end
-- Patched router matching: results are memoised in route_lrucache under a
-- key generated by get_cache_key_func (uri, optionally + method/host).
-- The cache itself is rebuilt by routes_analyze whenever routes change,
-- which invalidates stale entries.
local function ai_router_http_matching(api_ctx)
    core.log.info("route match mode: ai_match")

    local key = get_cache_key_func(api_ctx)
    core.log.info("route cache key: ", key)
    local api_ctx_cache = route_lrucache(key, nil,
                                        create_router_matching_cache, api_ctx)
    -- if the version has not changed, use the cached route
    if api_ctx then
        api_ctx.matched_route = api_ctx_cache.matched_route
        if api_ctx_cache.curr_req_matched then
            api_ctx.curr_req_matched = core.table.clone(api_ctx_cache.curr_req_matched)
        end
    end
end
-- (Re)build get_cache_key_func from the current route flags by rendering
-- the key-function template and loading the resulting chunk.
-- @return true on success, or false, err
local function gen_get_cache_key_func(route_flags)
    -- compile the template once, on first use
    if not get_cache_key_func_def_render then
        local template = require("resty.template")
        get_cache_key_func_def_render = template.compile(get_cache_key_func_def)
    end

    local src = get_cache_key_func_def_render({route_flags = route_flags})
    local chunk, err = loadstring(src)
    if not chunk then
        return false, err
    end

    local ok, fn_or_err = pcall(chunk)
    if not ok then
        return false, fn_or_err
    end

    get_cache_key_func = fn_or_err
    return true
end
-- Replacement for apisix.handle_upstream (installed by routes_analyze):
-- a no-op, since ai_http_balancer_phase picks the single node directly.
local function ai_upstream()
    core.log.info("enable sample upstream")
end
-- keepalive pool options, filled in by _M.init
local pool_opt

-- Patched balancer phase: routes_analyze only installs this when every
-- route has exactly one upstream node, so the peer can be set directly
-- (with keepalive when the platform supports it).
local function ai_http_balancer_phase()
    local api_ctx = ngx.ctx.api_ctx
    if not api_ctx then
        core.log.error("invalid api_ctx")
        return core.response.exit(500)
    end

    if is_apisix_or then
        -- APISIX-OpenResty only: skip the Lua body filter for this request
        local ok, err = response.skip_body_filter_by_lua()
        if not ok then
            core.log.error("failed to skip body filter by lua: ", err)
        end
    end

    local route = api_ctx.matched_route
    local server = route.value.upstream.nodes[1]
    if enable_keepalive then
        local ok, err = balancer.set_current_peer(server.host, server.port or 80, pool_opt)
        if not ok then
            core.log.error("failed to set server peer [", server.host, ":",
                           server.port, "] err: ", err)
            return ok, err
        end
        balancer.enable_keepalive(default_keepalive_pool.idle_timeout,
                                  default_keepalive_pool.requests)
    else
        balancer.set_current_peer(server.host, server.port or 80)
    end
end
-- Inspect the whole route table on every router rebuild and decide which
-- fast paths are safe:
--   * cached route matching: only when no route matches on vars /
--     filter_func / remote_addr, no service/plugin_config binding exists,
--     and no global rules are configured;
--   * upstream/balancer bypass: only when every route has a trivial
--     single-node, plain-HTTP, IP-addressed upstream and uses no
--     plugins / services / scripts.
-- Otherwise the original handlers are (re)installed.
local function routes_analyze(routes)
    -- save the stock handlers once, before any patching
    if orig_router_http_matching == nil then
        orig_router_http_matching = router.router_http.matching
    end

    if orig_handle_upstream == nil then
        orig_handle_upstream = apisix.handle_upstream
    end

    if orig_http_balancer_phase == nil then
        orig_http_balancer_phase = apisix.http_balancer_phase
    end

    local route_flags = core.table.new(0, 16)
    local route_up_flags = core.table.new(0, 12)
    for _, route in ipairs(routes) do
        if type(route) == "table" then
            for key, value in pairs(route.value) do
                -- collect route flags
                if key == "methods" then
                    route_flags["methods"] = true
                elseif key == "host" or key == "hosts" then
                    route_flags["host"] = true
                elseif key == "vars" then
                    route_flags["vars"] = true
                elseif key == "filter_func" then
                    route_flags["filter_func"] = true
                elseif key == "remote_addr" or key == "remote_addrs" then
                    route_flags["remote_addr"] = true
                elseif key == "service" then
                    route_flags["service"] = true
                elseif key == "enable_websocket" then
                    route_flags["enable_websocket"] = true
                elseif key == "plugins" then
                    route_flags["plugins"] = true
                elseif key == "upstream_id" then
                    route_flags["upstream_id"] = true
                elseif key == "service_id" then
                    route_flags["service_id"] = true
                elseif key == "plugin_config_id" then
                    route_flags["plugin_config_id"] = true
                elseif key == "script" then
                    route_flags["script"] = true
                end

                -- collect upstream flags
                if key == "upstream" then
                    if value.nodes and #value.nodes == 1 then
                        for k, v in pairs(value) do
                            if k == "nodes" then
                                -- a hostname (not a literal IP) would need
                                -- resolving, so the fast path is unusable
                                if (not core.utils.parse_ipv4(v[1].host)
                                    and not core.utils.parse_ipv6(v[1].host)) then
                                    route_up_flags["has_domain"] = true
                                end
                            elseif k == "pass_host" and v ~= "pass" then
                                route_up_flags["pass_host"] = true
                            elseif k == "scheme" and v ~= "http" then
                                route_up_flags["scheme"] = true
                            elseif k == "checks" then
                                route_up_flags["checks"] = true
                            elseif k == "retries" then
                                route_up_flags["retries"] = true
                            elseif k == "timeout" then
                                route_up_flags["timeout"] = true
                            elseif k == "tls" then
                                route_up_flags["tls"] = true
                            elseif k == "keepalive_pool" then
                                route_up_flags["keepalive_pool"] = true
                            elseif k == "service_name" then
                                route_up_flags["service_name"] = true
                            end
                        end
                    else
                        route_up_flags["more_nodes"] = true
                    end
                end
            end
        end
    end

    local global_rules, _ = get_global_rules()
    local global_rules_flag = global_rules and #global_rules ~= 0

    if route_flags["vars"] or route_flags["filter_func"]
        or route_flags["remote_addr"]
        or route_flags["service_id"]
        or route_flags["plugin_config_id"]
        or global_rules_flag then
        router.router_http.matching = orig_router_http_matching
    else
        core.log.info("use ai plane to match route")
        router.router_http.matching = ai_router_http_matching

        -- rebuilding the lru cache drops entries from the previous router
        local count = #routes + 3000
        core.log.info("renew route cache: count=", count)
        route_lrucache = core.lrucache.new({
            count = count
        })

        local ok, err = gen_get_cache_key_func(route_flags)
        if not ok then
            core.log.error("generate get_cache_key_func failed:", err)
            router.router_http.matching = orig_router_http_matching
        end
    end

    if route_flags["service"]
        or route_flags["script"]
        or route_flags["service_id"]
        or route_flags["upstream_id"]
        or route_flags["enable_websocket"]
        or route_flags["plugins"]
        or route_flags["plugin_config_id"]
        or route_up_flags["has_domain"]
        or route_up_flags["pass_host"]
        or route_up_flags["scheme"]
        or route_up_flags["checks"]
        or route_up_flags["retries"]
        or route_up_flags["timeout"]
        or route_up_flags["tls"]
        or route_up_flags["keepalive_pool"]
        or route_up_flags["service_name"]
        or route_up_flags["more_nodes"]
        or global_rules_flag then
        apisix.handle_upstream = orig_handle_upstream
        apisix.http_balancer_phase = orig_http_balancer_phase
    else
        -- replace the upstream and balancer module
        apisix.handle_upstream = ai_upstream
        apisix.http_balancer_phase = ai_http_balancer_phase
    end
end
-- Module init: hook the route-rebuild event and read the nginx upstream
-- keepalive settings used to build the default keepalive pool.
function _M.init()
    -- re-run the AI-plane analysis every time the router is rebuilt
    event.register(event.CONST.BUILD_ROUTER, routes_analyze)

    local local_conf = core.config.local_conf()
    local up_keepalive_conf =
        core.table.try_read_attr(local_conf, "nginx_config",
                                 "http", "upstream")
    -- Guard against a local config without the nginx_config.http.upstream
    -- section: try_read_attr returns nil in that case and indexing it
    -- below would crash the init phase.
    if not up_keepalive_conf then
        core.log.warn("missing nginx_config.http.upstream in local conf, ",
                      "skip keepalive pool initialization")
        return
    end

    default_keepalive_pool.idle_timeout =
        core.config_util.parse_time_unit(up_keepalive_conf.keepalive_timeout)
    default_keepalive_pool.size = up_keepalive_conf.keepalive
    default_keepalive_pool.requests = up_keepalive_conf.keepalive_requests

    pool_opt = { pool_size = default_keepalive_pool.size }
end
-- Module teardown: undo every replacement made by the AI plane.
-- Each saved original is cleared after being restored, so repeated
-- destroy() calls are harmless.
function _M.destroy()
    event.unregister(event.CONST.BUILD_ROUTER)

    if orig_http_balancer_phase then
        apisix.http_balancer_phase = orig_http_balancer_phase
        orig_http_balancer_phase = nil
    end

    if orig_handle_upstream then
        apisix.handle_upstream = orig_handle_upstream
        orig_handle_upstream = nil
    end

    if orig_router_http_matching then
        router.router_http.matching = orig_router_http_matching
        orig_router_http_matching = nil
    end
end
return _M

View File

@@ -0,0 +1,267 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local plugin_name = "api-breaker"
local ngx = ngx
local math = math
local error = error
local ipairs = ipairs
local shared_buffer = ngx.shared["plugin-".. plugin_name]
if not shared_buffer then
error("failed to get ngx.shared dict when load plugin " .. plugin_name)
end
-- Plugin configuration schema.
local schema = {
    type = "object",
    properties = {
        -- status code returned to clients while the breaker is open
        break_response_code = {
            type = "integer",
            minimum = 200,
            maximum = 599,
        },
        -- optional response body returned while the breaker is open
        break_response_body = {
            type = "string"
        },
        -- extra headers added to the breaker response; values may
        -- reference nginx variables (resolved per request in access())
        break_response_headers = {
            type = "array",
            items = {
                type = "object",
                properties = {
                    key = {
                        type = "string",
                        minLength = 1
                    },
                    value = {
                        type = "string",
                        minLength = 1
                    }
                },
                required = {"key", "value"},
            }
        },
        -- upper bound (seconds) on the exponential open-state duration
        max_breaker_sec = {
            type = "integer",
            minimum = 3,
            default = 300,
        },
        -- thresholds that trip the breaker
        unhealthy = {
            type = "object",
            properties = {
                http_statuses = {
                    type = "array",
                    minItems = 1,
                    items = {
                        type = "integer",
                        minimum = 500,
                        maximum = 599,
                    },
                    uniqueItems = true,
                    default = {500}
                },
                failures = {
                    type = "integer",
                    minimum = 1,
                    default = 3,
                }
            },
            default = {http_statuses = {500}, failures = 3}
        },
        -- thresholds that close the breaker again
        healthy = {
            type = "object",
            properties = {
                http_statuses = {
                    type = "array",
                    minItems = 1,
                    items = {
                        type = "integer",
                        minimum = 200,
                        maximum = 499,
                    },
                    uniqueItems = true,
                    default = {200}
                },
                successes = {
                    type = "integer",
                    minimum = 1,
                    default = 3,
                }
            },
            default = {http_statuses = {200}, successes = 3}
        }
    },
    required = {"break_response_code"},
}
-- Shared-dict key counting consecutive healthy upstream responses,
-- scoped to request host + uri.
local function gen_healthy_key(ctx)
    local host = core.request.get_host(ctx)
    return "healthy-" .. host .. ctx.var.uri
end
-- Shared-dict key counting failed upstream responses,
-- scoped to request host + uri.
local function gen_unhealthy_key(ctx)
    local host = core.request.get_host(ctx)
    return "unhealthy-" .. host .. ctx.var.uri
end
-- Shared-dict key storing the timestamp of the last unhealthy trigger.
-- NOTE(review): unlike the other key builders there is no separator
-- after "lasttime" before the host -- presumably intentional, but worth
-- confirming it cannot collide with another key space.
local function gen_lasttime_key(ctx)
    return "unhealthy-lasttime" .. core.request.get_host(ctx) .. ctx.var.uri
end
local _M = {
version = 0.1,
name = plugin_name,
priority = 1005,
schema = schema,
}
-- Validate the plugin configuration against the schema.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end
-- Short-circuit the request while the breaker for this host+uri is open.
-- Returns nothing (request proceeds) unless the breaker is tripped, in
-- which case the configured break_response_code/body/headers are returned.
-- Shared-dict read errors fail open (request is let through).
function _M.access(conf, ctx)
    local unhealthy_key = gen_unhealthy_key(ctx)
    -- unhealthy counts
    local unhealthy_count, err = shared_buffer:get(unhealthy_key)
    if err then
        core.log.warn("failed to get unhealthy_key: ",
                      unhealthy_key, " err: ", err)
        return
    end

    if not unhealthy_count then
        return
    end

    -- timestamp of the last time a unhealthy state was triggered
    local lasttime_key = gen_lasttime_key(ctx)
    local lasttime, err = shared_buffer:get(lasttime_key)
    if err then
        core.log.warn("failed to get lasttime_key: ",
                      lasttime_key, " err: ", err)
        return
    end

    if not lasttime then
        return
    end

    -- exponential backoff: each full batch of `unhealthy.failures`
    -- failures doubles the open-state duration (2^n seconds)
    local failure_times = math.ceil(unhealthy_count / conf.unhealthy.failures)
    if failure_times < 1 then
        failure_times = 1
    end

    -- cannot exceed the maximum value of the user configuration
    local breaker_time = 2 ^ failure_times
    if breaker_time > conf.max_breaker_sec then
        breaker_time = conf.max_breaker_sec
    end
    core.log.info("breaker_time: ", breaker_time)

    -- breaker still open: reply with the configured breaker response
    if lasttime + breaker_time >= ngx.time() then
        if conf.break_response_body then
            if conf.break_response_headers then
                for _, value in ipairs(conf.break_response_headers) do
                    -- header values may reference nginx vars, e.g. "$remote_addr"
                    local val = core.utils.resolve_var(value.value, ctx.var)
                    core.response.add_header(value.key, val)
                end
            end
            return conf.break_response_code, conf.break_response_body
        end
        return conf.break_response_code
    end

    return
end
-- Track upstream response statuses per host+uri and drive breaker state.
-- Unhealthy statuses increment the failure counter (resetting the success
-- counter); every `unhealthy.failures`-th failure refreshes the open-state
-- timestamp. Healthy statuses increment the success counter; after
-- `healthy.successes` in a row all counters are cleared (breaker closes).
function _M.log(conf, ctx)
    local unhealthy_key = gen_unhealthy_key(ctx)
    local healthy_key = gen_healthy_key(ctx)
    local upstream_status = core.response.get_upstream_status(ctx)

    if not upstream_status then
        return
    end

    -- unhealthy process
    if core.table.array_find(conf.unhealthy.http_statuses,
                             upstream_status)
    then
        local unhealthy_count, err = shared_buffer:incr(unhealthy_key, 1, 0)
        if err then
            core.log.warn("failed to incr unhealthy_key: ", unhealthy_key,
                          " err: ", err)
        end
        core.log.info("unhealthy_key: ", unhealthy_key, " count: ",
                      unhealthy_count)

        -- any failure breaks the healthy streak
        shared_buffer:delete(healthy_key)

        -- whether the user-configured number of failures has been reached,
        -- and if so, the timestamp for entering the unhealthy state.
        if unhealthy_count % conf.unhealthy.failures == 0 then
            shared_buffer:set(gen_lasttime_key(ctx), ngx.time(),
                              conf.max_breaker_sec)
            core.log.info("update unhealthy_key: ", unhealthy_key, " to ",
                          unhealthy_count)
        end

        return
    end

    -- health process
    if not core.table.array_find(conf.healthy.http_statuses, upstream_status) then
        return
    end

    local unhealthy_count, err = shared_buffer:get(unhealthy_key)
    if err then
        core.log.warn("failed to `get` unhealthy_key: ", unhealthy_key,
                      " err: ", err)
    end

    -- nothing to clear when the route never failed
    if not unhealthy_count then
        return
    end

    local healthy_count, err = shared_buffer:incr(healthy_key, 1, 0)
    if err then
        core.log.warn("failed to `incr` healthy_key: ", healthy_key,
                      " err: ", err)
    end

    -- clear related status
    if healthy_count >= conf.healthy.successes then
        -- stat change to normal
        core.log.info("change to normal, ", healthy_key, " ", healthy_count)
        shared_buffer:delete(gen_lasttime_key(ctx))
        shared_buffer:delete(unhealthy_key)
        shared_buffer:delete(healthy_key)
    end

    return
end
return _M

View File

@@ -0,0 +1,68 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local pairs = pairs
local plugin_name = "attach-consumer-label"
-- Configuration schema: maps upstream header names to consumer label keys.
local schema = {
    type = "object",
    properties = {
        headers = {
            type = "object",
            -- every value must start with "$", e.g. {"X-Dept": "$department"};
            -- the "$" is stripped before looking up the consumer label
            additionalProperties = {
                type = "string",
                pattern = "^\\$.*"
            },
            minProperties = 1
        },
    },
    required = {"headers"},
}
local _M = {
version = 0.1,
priority = 2399,
name = plugin_name,
schema = schema,
}
-- Validate the plugin configuration against the schema.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    if not ok then
        return false, err
    end
    return true
end
-- Attach the authenticated consumer's label values as request headers
-- before the request is proxied upstream.
function _M.before_proxy(conf, ctx)
    local consumer = ctx.consumer
    if not consumer then
        -- no authenticated consumer on this request
        return
    end

    local labels = consumer.labels
    core.log.info("consumer username: ", consumer.username, " labels: ",
                  core.json.delay_encode(labels))
    if not labels then
        return
    end

    for header, label_key in pairs(conf.headers) do
        -- strip the leading "$" to obtain the label name
        core.request.set_header(ctx, header, labels[label_key:sub(2)])
    end
end
return _M

View File

@@ -0,0 +1,135 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local casbin = require("casbin")
local core = require("apisix.core")
local plugin = require("apisix.plugin")
local plugin_name = "authz-casbin"
-- Plugin configuration: either file-based (model_path/policy_path) or
-- inline text (model/policy). `username` names the request header that
-- carries the subject for enforcement.
local schema = {
    type = "object",
    properties = {
        model_path = { type = "string" },
        policy_path = { type = "string" },
        model = { type = "string" },
        policy = { type = "string" },
        username = { type = "string"}
    },
    oneOf = {
        {required = {"model_path", "policy_path", "username"}},
        {required = {"model", "policy", "username"}}
    },
}

-- Plugin metadata: a global model/policy pair shared by routes that only
-- configure `username`.
local metadata_schema = {
    type = "object",
    properties = {
        model = {type = "string"},
        policy = {type = "string"},
    },
    required = {"model", "policy"},
}
local _M = {
version = 0.1,
priority = 2560,
name = plugin_name,
schema = schema,
metadata_schema = metadata_schema
}
-- Validate the plugin (or metadata) configuration.
-- A conf that fails the plugin schema is still accepted when plugin
-- metadata exists and the conf carries a username: enforcement then
-- falls back to the metadata-provided model/policy (see
-- new_enforcer_if_need).
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_METADATA then
        return core.schema.check(metadata_schema, conf)
    end
    local ok, err = core.schema.check(schema, conf)
    if ok then
        return true
    else
        local metadata = plugin.plugin_metadata(plugin_name)
        if metadata and metadata.value and conf.username then
            return true
        end
    end
    return false, err
end
local casbin_enforcer
-- Lazily create the casbin enforcer for this conf.
-- Precedence: file paths on the conf, then inline model/policy text on
-- the conf (both cached on the conf table itself as conf.casbin_enforcer),
-- then plugin metadata (cached in the module-level `casbin_enforcer`,
-- rebuilt whenever the metadata's modifiedIndex changes).
-- Returns true on success, or nil plus an error message.
local function new_enforcer_if_need(conf)
    if conf.model_path and conf.policy_path then
        local model_path = conf.model_path
        local policy_path = conf.policy_path
        if not conf.casbin_enforcer then
            conf.casbin_enforcer = casbin:new(model_path, policy_path)
        end
        return true
    end

    if conf.model and conf.policy then
        local model = conf.model
        local policy = conf.policy
        if not conf.casbin_enforcer then
            conf.casbin_enforcer = casbin:newEnforcerFromText(model, policy)
        end
        return true
    end

    local metadata = plugin.plugin_metadata(plugin_name)
    if not (metadata and metadata.value.model and metadata.value.policy) then
        return nil, "not enough configuration to create enforcer"
    end

    -- rebuild the shared enforcer only when the metadata has changed
    local modifiedIndex = metadata.modifiedIndex
    if not casbin_enforcer or casbin_enforcer.modifiedIndex ~= modifiedIndex then
        local model = metadata.value.model
        local policy = metadata.value.policy
        casbin_enforcer = casbin:newEnforcerFromText(model, policy)
        casbin_enforcer.modifiedIndex = modifiedIndex
    end
    return true
end
-- Enforce the casbin policy for the current request; denies with 403
-- when the (subject, path, method) triple is not authorized.
function _M.rewrite(conf, ctx)
    -- creates an enforcer when request sent for the first time
    local ok, err = new_enforcer_if_need(conf)
    if not ok then
        core.log.error(err)
        return 503
    end

    local headers = core.request.headers(ctx)
    local subject = headers[conf.username] or "anonymous"

    -- a per-conf enforcer takes precedence over the shared
    -- metadata-based one
    local enforcer = conf.casbin_enforcer or casbin_enforcer
    if not enforcer:enforce(subject, ctx.var.uri, ctx.var.method) then
        return 403, {message = "Access Denied"}
    end
end
return _M

View File

@@ -0,0 +1,176 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local http = require("resty.http")
local session = require("resty.session")
local ngx = ngx
local rand = math.random
local tostring = tostring
local plugin_name = "authz-casdoor"
-- Plugin configuration schema.
local schema = {
    type = "object",
    properties = {
        -- Note: endpoint_addr and callback_url should not end with '/'
        endpoint_addr = {type = "string", pattern = "^[^%?]+[^/]$"},
        client_id = {type = "string"},
        client_secret = {type = "string"},
        callback_url = {type = "string", pattern = "^[^%?]+[^/]$"}
    },
    -- client_secret is stored encrypted at rest
    encrypt_fields = {"client_secret"},
    required = {
        "callback_url", "endpoint_addr", "client_id", "client_secret"
    }
}
local _M = {
version = 0.1,
priority = 2559,
name = plugin_name,
schema = schema
}
-- Exchange an OAuth authorization code for an access token at casdoor.
-- Returns (access_token, expires_in_seconds, nil) on success, or
-- (nil, nil, err) on any failure.
local function fetch_access_token(code, conf)
    local client = http.new()
    local url = conf.endpoint_addr .. "/api/login/oauth/access_token"
    local res, err = client:request_uri(url, {
        method = "POST",
        body = ngx.encode_args({
            code = code,
            grant_type = "authorization_code",
            client_id = conf.client_id,
            client_secret = conf.client_secret
        }),
        headers = {
            ["Content-Type"] = "application/x-www-form-urlencoded"
        }
    })
    if not res then
        return nil, nil, err
    end
    local data, err = core.json.decode(res.body)
    if err or not data then
        -- `err` can be nil while `data` is nil; guard the concatenation
        -- so error reporting itself cannot crash
        err = "failed to parse casdoor response data: " .. (err or "nil")
              .. ", body: " .. res.body
        return nil, nil, err
    end
    -- guard against valid JSON that is not an object (e.g. a bare string
    -- or number), which would otherwise crash when indexing data below
    if type(data) ~= "table" then
        return nil, nil,
               "failed to parse casdoor response data: not a JSON object, body: "
               .. res.body
    end
    if not data.access_token then
        return nil, nil,
               "failed when accessing token: no access_token contained"
    end
    -- In the reply of casdoor, setting expires_in to 0 indicates that the
    -- access_token is invalid.
    if not data.expires_in or data.expires_in == 0 then
        return nil, nil, "failed when accessing token: invalid access_token"
    end
    return data.access_token, data.expires_in, nil
end
-- Warn when plain-HTTP endpoints are configured, then validate the conf.
function _M.check_schema(conf)
    core.utils.check_https({"endpoint_addr", "callback_url"},
                           conf, plugin_name)
    return core.schema.check(schema, conf)
end
-- OAuth2 authorization-code flow against casdoor:
--  1. if the request hits the configured callback path, validate the
--     state, exchange the code for an access token, store the token in a
--     fresh session and redirect back to the originally requested URI;
--  2. otherwise, when no authenticated session exists, remember the
--     current URI plus a random state and redirect to casdoor's login.
function _M.access(conf, ctx)
    local current_uri = ctx.var.uri
    local session_obj_read, session_present = session.open()
    -- step 1: check whether hits the callback
    local m, err = ngx.re.match(conf.callback_url, ".+//[^/]+(/.*)", "jo")
    if err or not m then
        core.log.error(err)
        return 503
    end
    local real_callback_url = m[1]
    if current_uri == real_callback_url then
        if not session_present then
            err = "no session found"
            core.log.error(err)
            return 503
        end
        local state_in_session = session_obj_read.data.state
        if not state_in_session then
            err = "no state found in session"
            core.log.error(err)
            return 503
        end
        local args = core.request.get_uri_args(ctx)
        if not args or not args.code or not args.state then
            err = "failed when accessing token. Invalid code or state"
            core.log.error(err)
            return 400, err
        end
        -- CSRF protection: the state must round-trip unchanged
        if args.state ~= tostring(state_in_session) then
            err = "invalid state"
            core.log.error(err)
            return 400, err
        end
        -- args.code is already known to be non-nil here (checked above),
        -- so a separate code check would be unreachable
        local access_token, lifetime, err =
            fetch_access_token(args.code, conf)
        if not access_token then
            core.log.error(err)
            return 503
        end
        local original_url = session_obj_read.data.original_uri
        if not original_url then
            err = "no original_url found in session"
            core.log.error(err)
            return 503
        end
        -- open a fresh session whose cookie expires with the token
        local session_obj_write = session.new {
            cookie = {lifetime = lifetime}
        }
        session_obj_write:start()
        session_obj_write.data.access_token = access_token
        session_obj_write:save()
        core.response.set_header("Location", original_url)
        return 302
    end

    -- step 2: check whether session exists
    if not (session_present and session_obj_read.data.access_token) then
        -- session not exists, redirect to login page
        local state = rand(0x7fffffff)
        local session_obj_write = session.start()
        session_obj_write.data.original_uri = current_uri
        session_obj_write.data.state = state
        session_obj_write:save()

        local redirect_url = conf.endpoint_addr .. "/login/oauth/authorize?" .. ngx.encode_args({
            response_type = "code",
            scope = "read",
            state = state,
            client_id = conf.client_id,
            redirect_uri = conf.callback_url
        })
        core.response.set_header("Location", redirect_url)
        return 302
    end
end
return _M

View File

@@ -0,0 +1,790 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local http = require "resty.http"
local sub_str = string.sub
local type = type
local ngx = ngx
local plugin_name = "authz-keycloak"
local fetch_secrets = require("apisix.secret").fetch_secrets
local log = core.log
local pairs = pairs
local schema = {
type = "object",
properties = {
discovery = {type = "string", minLength = 1, maxLength = 4096},
token_endpoint = {type = "string", minLength = 1, maxLength = 4096},
resource_registration_endpoint = {type = "string", minLength = 1, maxLength = 4096},
client_id = {type = "string", minLength = 1, maxLength = 100},
client_secret = {type = "string", minLength = 1, maxLength = 100},
grant_type = {
type = "string",
default="urn:ietf:params:oauth:grant-type:uma-ticket",
enum = {"urn:ietf:params:oauth:grant-type:uma-ticket"},
minLength = 1, maxLength = 100
},
policy_enforcement_mode = {
type = "string",
enum = {"ENFORCING", "PERMISSIVE"},
default = "ENFORCING"
},
permissions = {
type = "array",
items = {
type = "string",
minLength = 1, maxLength = 100
},
uniqueItems = true,
default = {}
},
lazy_load_paths = {type = "boolean", default = false},
http_method_as_scope = {type = "boolean", default = false},
timeout = {type = "integer", minimum = 1000, default = 3000},
ssl_verify = {type = "boolean", default = true},
cache_ttl_seconds = {type = "integer", minimum = 1, default = 24 * 60 * 60},
keepalive = {type = "boolean", default = true},
keepalive_timeout = {type = "integer", minimum = 1000, default = 60000},
keepalive_pool = {type = "integer", minimum = 1, default = 5},
access_denied_redirect_uri = {type = "string", minLength = 1, maxLength = 2048},
access_token_expires_in = {type = "integer", minimum = 1, default = 300},
access_token_expires_leeway = {type = "integer", minimum = 0, default = 0},
refresh_token_expires_in = {type = "integer", minimum = 1, default = 3600},
refresh_token_expires_leeway = {type = "integer", minimum = 0, default = 0},
password_grant_token_generation_incoming_uri = {
type = "string",
minLength = 1,
maxLength = 4096
},
},
encrypt_fields = {"client_secret"},
required = {"client_id"},
allOf = {
-- Require discovery or token endpoint.
{
anyOf = {
{required = {"discovery"}},
{required = {"token_endpoint"}}
}
},
-- If lazy_load_paths is true, require discovery or resource registration endpoint.
{
anyOf = {
{
properties = {
lazy_load_paths = {enum = {false}},
}
},
{
properties = {
lazy_load_paths = {enum = {true}},
},
anyOf = {
{required = {"discovery"}},
{required = {"resource_registration_endpoint"}}
}
}
}
}
}
}
local _M = {
version = 0.1,
priority = 2000,
name = plugin_name,
schema = schema,
}
-- Emit HTTPS/TLS usage warnings for endpoint-like fields, then validate
-- the configuration against the schema.
function _M.check_schema(conf)
    core.utils.check_https({"discovery", "token_endpoint",
                            "resource_registration_endpoint",
                            "access_denied_redirect_uri"},
                           conf, plugin_name)
    core.utils.check_tls_bool({"ssl_verify"}, conf, plugin_name)
    return core.schema.check(schema, conf)
end
-- Some auxiliary functions below heavily inspired by the excellent
-- lua-resty-openidc module; see https://github.com/zmartzone/lua-resty-openidc
-- Retrieve value from server-wide cache, if available.
-- Retrieve a value from the server-wide shared-dict cache, if available.
local function authz_keycloak_cache_get(type, key)
    local dict = ngx.shared[type]
    if not dict then
        return nil
    end
    local value = dict:get(key)
    if value then
        log.debug("cache hit: type=", type, " key=", key)
    end
    return value
end
-- Set value in server-wide cache, if available.
-- Store a value in the server-wide shared-dict cache, if available and
-- the expiry is positive.
local function authz_keycloak_cache_set(type, key, value, exp)
    local dict = ngx.shared[type]
    if not dict or exp <= 0 then
        return
    end
    local success, err, forcible = dict:set(key, value, exp)
    -- log at error level only when the set itself reported an error
    local emit = err and log.error or log.debug
    emit("cache set: success=", success, " err=", err, " forcible=", forcible)
end
-- Configure request parameters.
-- Populate an outgoing HTTP request parameter table with keepalive and
-- TLS-verification options from the conf; returns the (possibly
-- decorated) table.
local function authz_keycloak_configure_params(params, conf)
    -- Keepalive options.
    if conf.keepalive then
        params.keepalive_timeout = conf.keepalive_timeout
        params.keepalive_pool = conf.keepalive_pool
    else
        params.keepalive = conf.keepalive
    end

    -- TLS verification.
    params.ssl_verify = conf.ssl_verify

    -- Apply the optional request decorator.
    if conf.http_request_decorator then
        return conf.http_request_decorator(params) or params
    end
    return params
end
-- Configure timeouts.
-- Apply a timeout setting to an HTTP client: a plain value sets one
-- overall timeout, a table sets connect/send/read individually with
-- missing entries falling back to 0.
local function authz_keycloak_configure_timeouts(httpc, timeout)
    if not timeout then
        return
    end
    if type(timeout) ~= "table" then
        httpc:set_timeout(timeout)
        return
    end
    httpc:set_timeouts(timeout.connect or 0, timeout.send or 0,
                       timeout.read or 0)
end
-- Set outgoing proxy options.
-- Set outgoing proxy options on the client when a proxy options table
-- is provided.
local function authz_keycloak_configure_proxy(httpc, proxy_opts)
    if not (httpc and type(proxy_opts) == "table") then
        log.debug("authz_keycloak_configure_proxy : don't use http proxy")
        return
    end
    log.debug("authz_keycloak_configure_proxy : use http proxy")
    httpc:set_proxy_options(proxy_opts)
end
-- Get and configure HTTP client.
-- Get and configure HTTP client: a fresh resty.http client with the
-- conf's timeouts and (optional) proxy options applied.
local function authz_keycloak_get_http_client(conf)
    local httpc = http.new()
    authz_keycloak_configure_timeouts(httpc, conf.timeout)
    authz_keycloak_configure_proxy(httpc, conf.proxy_opts)
    return httpc
end
-- Parse the JSON result from a call to the OP.
-- Parse the JSON body of a response from the OP.
-- Returns the decoded table on success; (nil, err) for a non-200 status
-- or an undecodable body.
local function authz_keycloak_parse_json_response(response)
    if response.status ~= 200 then
        return nil, "response indicates failure, status=" .. response.status
                    .. ", body=" .. response.body
    end

    -- Decode the response and extract the JSON object.
    local res, err = response.body and core.json.decode(response.body), nil
    res, err = core.json.decode(response.body)
    if res then
        return res, err
    end
    return nil, "JSON decoding failed: " .. err
end
-- Get the Discovery metadata from the specified URL.
-- Get the Discovery metadata from the specified URL, caching the decoded
-- document (re-encoded as JSON) in the "discovery" shared dict for
-- conf.cache_ttl_seconds. Returns (table, nil) on success or (nil, err).
local function authz_keycloak_discover(conf)
    log.debug("authz_keycloak_discover: URL is: " .. conf.discovery)

    local json, err
    local v = authz_keycloak_cache_get("discovery", conf.discovery)

    if not v then
        log.debug("Discovery data not in cache, making call to discovery endpoint.")
        -- Make the call to the discovery endpoint.
        local httpc = authz_keycloak_get_http_client(conf)

        local params = authz_keycloak_configure_params({}, conf)

        local res, error = httpc:request_uri(conf.discovery, params)

        if not res then
            err = "Accessing discovery URL (" .. conf.discovery .. ") failed: " .. error
            log.error(err)
        else
            log.debug("Response data: " .. res.body)
            json, err = authz_keycloak_parse_json_response(res)
            if json then
                -- cache the normalized (decode + re-encode) document
                authz_keycloak_cache_set("discovery", conf.discovery, core.json.encode(json),
                                         conf.cache_ttl_seconds)
            else
                err = "could not decode JSON from Discovery data" .. (err and (": " .. err) or '')
                log.error(err)
            end
        end
    else
        -- cache hit: stored value is the JSON-encoded document
        json = core.json.decode(v)
    end

    return json, err
end
-- Turn a discovery url set in the conf dictionary into the discovered information.
-- Replace a string discovery URL on the conf with the fetched discovery
-- document; a no-op when discovery is already a table or absent.
-- Returns nil on success, or an error message.
local function authz_keycloak_ensure_discovered_data(conf)
    if type(conf.discovery) ~= "string" then
        return nil
    end
    local discovery, err = authz_keycloak_discover(conf)
    if err then
        return err
    end
    conf.discovery = discovery
    return nil
end
-- Get an endpoint from the configuration.
-- Look up an endpoint URL from the configuration: an explicit conf
-- entry wins; otherwise fall back to discovered metadata, when present.
-- Returns nil when the endpoint cannot be determined.
local function authz_keycloak_get_endpoint(conf, endpoint)
    if not conf then
        return nil
    end

    local explicit = conf[endpoint]
    if explicit then
        return explicit
    end

    local discovery = conf.discovery
    if type(discovery) == "table" then
        return discovery[endpoint]
    end

    return nil
end
-- Return the token endpoint from the configuration.
local function authz_keycloak_get_token_endpoint(conf)
return authz_keycloak_get_endpoint(conf, "token_endpoint")
end
-- Return the resource registration endpoint from the configuration.
local function authz_keycloak_get_resource_registration_endpoint(conf)
return authz_keycloak_get_endpoint(conf, "resource_registration_endpoint")
end
-- Return access_token expires_in value (in seconds).
-- Effective lifetime (seconds) of an access token: the server-provided
-- expires_in (or the configured default) minus one second of slack and
-- the configured leeway.
local function authz_keycloak_access_token_expires_in(conf, expires_in)
    local base = expires_in or conf.access_token_expires_in
    return base - 1 - conf.access_token_expires_leeway
end
-- Return refresh_token expires_in value (in seconds).
-- Effective lifetime (seconds) of a refresh token: the server-provided
-- expires_in (or the configured default) minus one second of slack and
-- the configured leeway.
local function authz_keycloak_refresh_token_expires_in(conf, expires_in)
    local base = expires_in or conf.refresh_token_expires_in
    return base - 1 - conf.refresh_token_expires_leeway
end
-- Ensure a valid service account access token is available for the configured client.
-- Ensure a valid service account access token is available for the
-- configured client.
-- Token state is cached in the "access-tokens" shared dict, keyed by
-- token_endpoint .. ":" .. client_id. An expired access token is renewed
-- via the refresh token when one is still valid; otherwise a fresh token
-- is obtained with the client_credentials grant.
-- NOTE(review): failure returns are mixed -- (503, err)/(500, err) in the
-- early checks but (nil, err) for HTTP/JSON failures; callers appear to
-- treat any second value as an error, but worth confirming.
local function authz_keycloak_ensure_sa_access_token(conf)
    local client_id = conf.client_id
    local ttl = conf.cache_ttl_seconds

    local token_endpoint = authz_keycloak_get_token_endpoint(conf)
    if not token_endpoint then
        log.error("Unable to determine token endpoint.")
        return 503, "Unable to determine token endpoint."
    end

    local session = authz_keycloak_cache_get("access-tokens", token_endpoint .. ":"
                                             .. client_id)

    if session then
        -- Decode session string.
        local err
        session, err = core.json.decode(session)

        if not session then
            -- Should never happen.
            return 500, err
        end

        local current_time = ngx.time()

        if current_time < session.access_token_expiration then
            -- Access token is still valid.
            log.debug("Access token is still valid.")
            return session.access_token
        else
            -- Access token has expired.
            log.debug("Access token has expired.")
            if session.refresh_token
               and (not session.refresh_token_expiration
                    or current_time < session.refresh_token_expiration) then
                -- Try to get a new access token, using the refresh token.
                log.debug("Trying to get new access token using refresh token.")

                local httpc = authz_keycloak_get_http_client(conf)

                local params = {
                    method = "POST",
                    body = ngx.encode_args({
                        grant_type = "refresh_token",
                        client_id = client_id,
                        client_secret = conf.client_secret,
                        refresh_token = session.refresh_token,
                    }),
                    headers = {
                        ["Content-Type"] = "application/x-www-form-urlencoded"
                    }
                }

                params = authz_keycloak_configure_params(params, conf)

                local res, err = httpc:request_uri(token_endpoint, params)

                if not res then
                    err = "Accessing token endpoint URL (" .. token_endpoint
                          .. ") failed: " .. err
                    log.error(err)
                    return nil, err
                end

                log.debug("Response data: " .. res.body)
                local json, err = authz_keycloak_parse_json_response(res)

                if not json then
                    err = "Could not decode JSON from token endpoint"
                          .. (err and (": " .. err) or '.')
                    log.error(err)
                    return nil, err
                end

                if not json.access_token then
                    -- Clear session; the client_credentials fallback below
                    -- will fetch a brand-new token.
                    log.debug("Answer didn't contain a new access token. Clearing session.")
                    session = nil
                else
                    log.debug("Got new access token.")
                    -- Save access token.
                    session.access_token = json.access_token

                    -- Calculate and save access token expiry time.
                    session.access_token_expiration = current_time
                        + authz_keycloak_access_token_expires_in(conf, json.expires_in)

                    -- Save refresh token, maybe.
                    if json.refresh_token ~= nil then
                        log.debug("Got new refresh token.")
                        session.refresh_token = json.refresh_token

                        -- Calculate and save refresh token expiry time.
                        session.refresh_token_expiration = current_time
                            + authz_keycloak_refresh_token_expires_in(conf,
                                                                     json.refresh_expires_in)
                    end

                    authz_keycloak_cache_set("access-tokens",
                                             token_endpoint .. ":" .. client_id,
                                             core.json.encode(session), ttl)
                end
            else
                -- No refresh token available, or it has expired. Clear session.
                log.debug("No or expired refresh token. Clearing session.")
                session = nil
            end
        end
    end

    if not session then
        -- No session available. Create a new one via client_credentials.
        log.debug("Getting access token for Protection API from token endpoint.")
        local httpc = authz_keycloak_get_http_client(conf)

        local params = {
            method = "POST",
            body = ngx.encode_args({
                grant_type = "client_credentials",
                client_id = client_id,
                client_secret = conf.client_secret,
            }),
            headers = {
                ["Content-Type"] = "application/x-www-form-urlencoded"
            }
        }

        params = authz_keycloak_configure_params(params, conf)
        local current_time = ngx.time()
        local res, err = httpc:request_uri(token_endpoint, params)

        if not res then
            err = "Accessing token endpoint URL (" .. token_endpoint .. ") failed: " .. err
            log.error(err)
            return nil, err
        end

        log.debug("Response data: " .. res.body)
        local json, err = authz_keycloak_parse_json_response(res)

        if not json then
            err = "Could not decode JSON from token endpoint" .. (err and (": " .. err) or '.')
            log.error(err)
            return nil, err
        end

        if not json.access_token then
            err = "Response does not contain access_token field."
            log.error(err)
            return nil, err
        end

        session = {}

        -- Save access token.
        session.access_token = json.access_token

        -- Calculate and save access token expiry time.
        session.access_token_expiration = current_time
            + authz_keycloak_access_token_expires_in(conf, json.expires_in)

        -- Save refresh token, maybe.
        if json.refresh_token ~= nil then
            session.refresh_token = json.refresh_token

            -- Calculate and save refresh token expiry time.
            session.refresh_token_expiration = current_time
                + authz_keycloak_refresh_token_expires_in(conf, json.refresh_expires_in)
        end

        authz_keycloak_cache_set("access-tokens", token_endpoint .. ":" .. client_id,
                                 core.json.encode(session), ttl)
    end

    return session.access_token
end
-- Resolve a URI to one or more resource IDs.
-- Resolve a URI to one or more resource IDs.
-- Queries the resource registration endpoint with matchingUri=true using
-- the service-account token; returns the decoded array of resource IDs,
-- or (nil, err) on failure.
local function authz_keycloak_resolve_resource(conf, uri, sa_access_token)
    -- Get resource registration endpoint URL.
    local resource_registration_endpoint = authz_keycloak_get_resource_registration_endpoint(conf)

    if not resource_registration_endpoint then
        local err = "Unable to determine registration endpoint."
        log.error(err)
        return nil, err
    end

    log.debug("Resource registration endpoint: ", resource_registration_endpoint)

    local httpc = authz_keycloak_get_http_client(conf)

    local params = {
        method = "GET",
        query = {uri = uri, matchingUri = "true"},
        headers = {
            ["Authorization"] = "Bearer " .. sa_access_token
        }
    }

    params = authz_keycloak_configure_params(params, conf)

    local res, err = httpc:request_uri(resource_registration_endpoint, params)

    if not res then
        err = "Accessing resource registration endpoint URL (" .. resource_registration_endpoint
              .. ") failed: " .. err
        log.error(err)
        return nil, err
    end

    log.debug("Response data: " .. res.body)

    -- The endpoint returns a bare JSON array; wrap it in an object so the
    -- shared status-checking JSON parser can be reused.
    res.body = '{"resources": ' .. res.body .. '}'

    local json, err = authz_keycloak_parse_json_response(res)

    if not json then
        err = "Could not decode JSON from resource registration endpoint"
              .. (err and (": " .. err) or '.')
        log.error(err)
        return nil, err
    end

    return json.resources
end
-- Evaluate whether the presented token grants access to the requested
-- resource/scope, by asking Keycloak's token endpoint for an authorization
-- decision (UMA, response_mode "decision").
--
-- @param conf  plugin configuration
-- @param ctx   request context (request URI and method are read from it)
-- @param token the client's Authorization header value ("Bearer <jwt>")
-- @return status, body on denial/error; returns nothing when accepted
local function evaluate_permissions(conf, ctx, token)
    -- Ensure discovered data.
    local err = authz_keycloak_ensure_discovered_data(conf)
    if err then
        return 503, err
    end

    local permission
    if conf.lazy_load_paths then
        -- Ensure service account access token.
        local sa_access_token, err = authz_keycloak_ensure_sa_access_token(conf)
        if err then
            log.error(err)
            return 503, err
        end

        -- Resolve URI to resource(s).
        permission, err = authz_keycloak_resolve_resource(conf, ctx.var.request_uri,
                                                          sa_access_token)

        -- Check result.
        if permission == nil then
            -- No result back from resource registration endpoint.
            log.error(err)
            return 503, err
        end
    else
        -- Use statically configured permissions.
        permission = conf.permissions
    end

    -- Return 403 or 307 if permission is empty and enforcement mode is "ENFORCING".
    if #permission == 0 and conf.policy_enforcement_mode == "ENFORCING" then
        -- Return Keycloak-style message for consistency.
        if conf.access_denied_redirect_uri then
            core.response.set_header("Location", conf.access_denied_redirect_uri)
            return 307
        end
        return 403, '{"error":"access_denied","error_description":"not_authorized"}'
    end

    -- Determine scope from HTTP method, maybe.
    local scope
    if conf.http_method_as_scope then
        scope = ctx.var.request_method
    end

    if scope then
        -- Loop over permissions and add scope.
        -- Keycloak permission syntax is "resource#scope1, scope2".
        for k, v in pairs(permission) do
            if v:find("#", 1, true) then
                -- Already contains scope; append as an additional one.
                permission[k] = v .. ", " .. scope
            else
                -- Doesn't contain scope yet.
                permission[k] = v .. "#" .. scope
            end
        end
    end

    for k, v in pairs(permission) do
        log.debug("Requesting permission ", v, ".")
    end

    -- Get token endpoint URL.
    local token_endpoint = authz_keycloak_get_token_endpoint(conf)
    if not token_endpoint then
        err = "Unable to determine token endpoint."
        log.error(err)
        return 503, err
    end
    log.debug("Token endpoint: ", token_endpoint)

    local httpc = authz_keycloak_get_http_client(conf)

    -- Forward the client's token and ask for a plain grant/deny decision
    -- instead of a full RPT.
    local params = {
        method = "POST",
        body = ngx.encode_args({
            grant_type = conf.grant_type,
            audience = conf.client_id,
            response_mode = "decision",
            permission = permission
        }),
        headers = {
            ["Content-Type"] = "application/x-www-form-urlencoded",
            ["Authorization"] = token
        }
    }

    params = authz_keycloak_configure_params(params, conf)

    local res, err = httpc:request_uri(token_endpoint, params)
    if not res then
        err = "Error while sending authz request to " .. token_endpoint .. ": " .. err
        log.error(err)
        return 503
    end

    log.debug("Response status: ", res.status, ", data: ", res.body)

    if res.status == 403 then
        -- Request permanently denied, e.g. due to lacking permissions.
        log.debug('Request denied: HTTP 403 Forbidden. Body: ', res.body)
        if conf.access_denied_redirect_uri then
            core.response.set_header("Location", conf.access_denied_redirect_uri)
            return 307
        end
        return res.status, res.body
    elseif res.status == 401 then
        -- Request temporarily denied, e.g access token not valid.
        log.debug('Request denied: HTTP 401 Unauthorized. Body: ', res.body)
        return res.status, res.body
    elseif res.status >= 400 then
        -- Some other error. Log full response.
        log.error('Request denied: Token endpoint returned an error (status: ',
                  res.status, ', body: ', res.body, ').')
        return res.status, res.body
    end

    -- Request accepted.
end
-- Fetch the JWT from the Authorization header, normalizing it to the
-- "Bearer <token>" form expected by the token endpoint.
local function fetch_jwt_token(ctx)
    local auth = core.request.header(ctx, "Authorization")
    if not auth then
        return nil, "authorization header not available"
    end

    -- Accept an existing "Bearer "/"bearer " scheme; anything else is
    -- treated as a bare token and gets the scheme prepended.
    local scheme = sub_str(auth, 1, 7)
    if scheme == 'Bearer ' or scheme == 'bearer ' then
        return auth
    end
    return "Bearer " .. auth
end
-- To get new access token by calling get token api
-- (resource-owner password credentials grant against the Keycloak token
-- endpoint; username/password are read from the urlencoded request body and
-- the raw Keycloak response is forwarded to the client).
local function generate_token_using_password_grant(conf,ctx)
    log.debug("generate_token_using_password_grant Function Called")

    local body, err = core.request.get_body()
    if err or not body then
        log.error("Failed to get request body: ", err)
        return 503
    end

    local parameters = core.string.decode_args(body)

    local username = parameters["username"]
    local password = parameters["password"]

    -- 422 Unprocessable Entity when credentials are missing.
    if not username then
        local err = "username is missing."
        log.warn(err)
        return 422, {message = err}
    end
    if not password then
        local err = "password is missing."
        log.warn(err)
        return 422, {message = err}
    end

    local client_id = conf.client_id

    local token_endpoint = authz_keycloak_get_token_endpoint(conf)

    if not token_endpoint then
        local err = "Unable to determine token endpoint."
        log.error(err)
        return 503, {message = err}
    end
    local httpc = authz_keycloak_get_http_client(conf)

    local params = {
        method = "POST",
        body = ngx.encode_args({
            grant_type = "password",
            client_id = client_id,
            client_secret = conf.client_secret,
            username = username,
            password = password
        }),
        headers = {
            ["Content-Type"] = "application/x-www-form-urlencoded"
        }
    }

    params = authz_keycloak_configure_params(params, conf)

    local res, err = httpc:request_uri(token_endpoint, params)

    if not res then
        err = "Accessing token endpoint URL (" .. token_endpoint
              .. ") failed: " .. err
        log.error(err)
        return 401, {message = "Accessing token endpoint URL failed."}
    end

    log.debug("Response data: " .. res.body)

    -- Parse only to validate the payload; on success the response body is
    -- returned verbatim with Keycloak's own status code.
    local json, err = authz_keycloak_parse_json_response(res)

    if not json then
        err = "Could not decode JSON from response"
              .. (err and (": " .. err) or '.')
        log.error(err)
        return 401, {message = "Could not decode JSON from response."}
    end

    return res.status, res.body
end
-- Plugin entry point (access phase): optionally serve password-grant token
-- generation on a dedicated URI, otherwise authenticate the Bearer token
-- and evaluate permissions against Keycloak.
function _M.access(conf, ctx)
    -- resolve secrets
    conf = fetch_secrets(conf, true, conf, "")
    local headers = core.request.headers(ctx)
    -- Token generation is only triggered by an exact URI match on an
    -- urlencoded POST request.
    local need_grant_token = conf.password_grant_token_generation_incoming_uri and
        ctx.var.request_uri == conf.password_grant_token_generation_incoming_uri and
        headers["content-type"] == "application/x-www-form-urlencoded" and
        core.request.get_method() == "POST"

    if need_grant_token then
        return generate_token_using_password_grant(conf,ctx)
    end
    log.debug("hit keycloak-auth access")
    local jwt_token, err = fetch_jwt_token(ctx)
    if not jwt_token then
        log.error("failed to fetch JWT token: ", err)
        return 401, {message = "Missing JWT token in request"}
    end

    -- evaluate_permissions returns nothing when the request is accepted.
    local status, body = evaluate_permissions(conf, ctx, jwt_token)
    if status then
        return status, body
    end
end

return _M

View File

@@ -0,0 +1,187 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
local ngx = ngx
local hmac = require("resty.hmac")
local hex_encode = require("resty.string").to_hex
local resty_sha256 = require("resty.sha256")
local str_strip = require("pl.stringx").strip
local norm_path = require("pl.path").normpath
local pairs = pairs
local tab_concat = table.concat
local tab_sort = table.sort
local os = os
local plugin_name = "aws-lambda"
local plugin_version = 0.1
local priority = -1899
local ALGO = "AWS4-HMAC-SHA256"
-- Raw (binary) HMAC-SHA256 of msg under key.
local function hmac256(key, msg)
    local mac = hmac:new(key, hmac.ALGOS.SHA256)
    return mac:final(msg)
end
-- Hex-encoded SHA-256 digest of msg.
local function sha256(msg)
    local hasher = resty_sha256:new()
    hasher:update(msg)
    return hex_encode(hasher:final())
end
-- Derive the AWS SigV4 signing key:
-- kSecret -> kDate -> kRegion -> kService -> kSigning.
local function get_signature_key(key, datestamp, region, service)
    local derived = hmac256("AWS4" .. key, datestamp)
    derived = hmac256(derived, region)
    derived = hmac256(derived, service)
    return hmac256(derived, "aws4_request")
end
-- Schema for the aws-lambda plugin's `authorization` attribute: either a
-- plain API key or IAM credentials used for SigV4 request signing.
local aws_authz_schema = {
    type = "object",
    properties = {
        -- API Key based authorization
        apikey = {type = "string"},
        -- IAM role based authorization, works via aws v4 request signing
        -- more at https://docs.aws.amazon.com/general/latest/gr/sigv4_signing.html
        iam = {
            type = "object",
            properties = {
                accesskey = {
                    type = "string",
                    description = "access key id from from aws iam console"
                },
                secretkey = {
                    type = "string",
                    description = "secret access key from from aws iam console"
                },
                aws_region = {
                    type = "string",
                    default = "us-east-1",
                    description = "the aws region that is receiving the request"
                },
                service = {
                    type = "string",
                    default = "execute-api",
                    description = "the service that is receiving the request"
                }
            },
            required = {"accesskey", "secretkey"}
        }
    }
}
-- Prepare the outgoing Lambda request: inject the API key header, or sign
-- the request with AWS Signature V4 when IAM credentials are configured.
-- Mutates params.headers in place; never overwrites client-supplied
-- credentials.
local function request_processor(conf, ctx, params)
    local headers = params.headers
    -- set authorization headers if not already set by the client
    -- we are following not to overwrite the authz keys
    if not headers["x-api-key"] then
        if conf.authorization and conf.authorization.apikey then
            headers["x-api-key"] = conf.authorization.apikey
            return
        end
    end

    -- performing aws v4 request signing for IAM authorization
    -- visit https://docs.aws.amazon.com/general/latest/gr/sigv4-signed-request-examples.html
    -- to look at the pseudocode in python.
    if headers["authorization"] or not conf.authorization or not conf.authorization.iam then
        return
    end

    -- create a date for headers and the credential string
    local t = ngx.time()
    local amzdate = os.date("!%Y%m%dT%H%M%SZ", t)
    local datestamp = os.date("!%Y%m%d", t) -- Date w/o time, used in credential scope
    headers["X-Amz-Date"] = amzdate

    -- computing canonical uri: strip a trailing slash and ensure a leading one
    local canonical_uri = norm_path(params.path)
    if canonical_uri ~= "/" then
        if canonical_uri:sub(-1, -1) == "/" then
            canonical_uri = canonical_uri:sub(1, -2)
        end
        if canonical_uri:sub(1, 1) ~= "/" then
            canonical_uri = "/" .. canonical_uri
        end
    end

    -- computing canonical query string
    -- NOTE(review): entries are *unescaped* and sorted as whole "k=v"
    -- strings; the SigV4 spec calls for URI-encoded parameters sorted by
    -- key name — confirm this matches the query shapes actually in use.
    local canonical_qs = {}
    local canonical_qs_i = 0
    for k, v in pairs(params.query) do
        canonical_qs_i = canonical_qs_i + 1
        canonical_qs[canonical_qs_i] = ngx.unescape_uri(k) .. "=" .. ngx.unescape_uri(v)
    end

    tab_sort(canonical_qs)
    canonical_qs = tab_concat(canonical_qs, "&")

    -- computing canonical and signed headers
    -- canonical_headers is used first as a map (name -> stripped value) and
    -- then, below, as an array of "name:value\n" lines aligned with the
    -- sorted signed_headers list (string and integer keys never collide).
    local canonical_headers, signed_headers = {}, {}
    local signed_headers_i = 0
    for k, v in pairs(headers) do
        k = k:lower()
        if k ~= "connection" then
            signed_headers_i = signed_headers_i + 1
            signed_headers[signed_headers_i] = k
            -- strip starting and trailing spaces including strip multiple spaces into single space
            canonical_headers[k] = str_strip(v)
        end
    end
    tab_sort(signed_headers)

    for i = 1, #signed_headers do
        local k = signed_headers[i]
        canonical_headers[i] = k .. ":" .. canonical_headers[k] .. "\n"
    end

    canonical_headers = tab_concat(canonical_headers, nil, 1, #signed_headers)
    signed_headers = tab_concat(signed_headers, ";")

    -- combining elements to form the canonical request (step-1)
    local canonical_request = params.method:upper() .. "\n"
        .. canonical_uri .. "\n"
        .. (canonical_qs or "") .. "\n"
        .. canonical_headers .. "\n"
        .. signed_headers .. "\n"
        .. sha256(params.body or "")

    -- creating the string to sign for aws signature v4 (step-2)
    local iam = conf.authorization.iam
    local credential_scope = datestamp .. "/" .. iam.aws_region .. "/"
        .. iam.service .. "/aws4_request"
    local string_to_sign = ALGO .. "\n"
        .. amzdate .. "\n"
        .. credential_scope .. "\n"
        .. sha256(canonical_request)

    -- calculate the signature (step-3)
    local signature_key = get_signature_key(iam.secretkey, datestamp, iam.aws_region, iam.service)
    local signature = hex_encode(hmac256(signature_key, string_to_sign))

    -- add info to the headers (step-4)
    headers["authorization"] = ALGO .. " Credential=" .. iam.accesskey
        .. "/" .. credential_scope
        .. ", SignedHeaders=" .. signed_headers
        .. ", Signature=" .. signature
end
-- Build the plugin from the shared serverless generic-upstream factory,
-- which wires in schema checking and request forwarding.
local serverless_obj = require("apisix.plugins.serverless.generic-upstream")

return serverless_obj(plugin_name, plugin_version, priority, request_processor, aws_authz_schema)

View File

@@ -0,0 +1,61 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
local plugin = require("apisix.plugin")
local plugin_name, plugin_version, priority = "azure-functions", 0.1, -1900
-- Per-route authorization attributes: an API key and/or client id
-- forwarded to the Azure Functions host.
local azure_authz_schema = {
    type = "object",
    properties = {
        apikey = {type = "string"},
        clientid = {type = "string"}
    }
}

-- Cluster-wide metadata: master credentials used as a fallback when
-- neither the client request nor the route configuration provides keys.
local metadata_schema = {
    type = "object",
    properties = {
        master_apikey = {type = "string", default = ""},
        master_clientid = {type = "string", default = ""}
    }
}
-- Inject Azure Functions credentials into the outgoing request headers.
-- Client-supplied credentials win over plugin config, which wins over the
-- master keys stored in plugin metadata.
local function request_processor(conf, ctx, params)
    local headers = params.headers or {}

    local has_creds = headers["x-functions-key"] or headers["x-functions-clientid"]
    if not has_creds then
        local authz = conf.authorization
        if authz then
            headers["x-functions-key"] = authz.apikey
            headers["x-functions-clientid"] = authz.clientid
        else
            -- If neither api keys are set with the client request nor inside the plugin attributes
            -- plugin will fallback to the master key (if any) present inside the metadata.
            local metadata = plugin.plugin_metadata(plugin_name)
            if metadata then
                headers["x-functions-key"] = metadata.value.master_apikey
                headers["x-functions-clientid"] = metadata.value.master_clientid
            end
        end
    end

    params.headers = headers
end
-- Build the plugin from the shared serverless generic-upstream factory.
return require("apisix.plugins.serverless.generic-upstream")(plugin_name,
    plugin_version, priority, request_processor, azure_authz_schema, metadata_schema)

View File

@@ -0,0 +1,189 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local ngx = ngx
local ngx_re = require("ngx.re")
local consumer = require("apisix.consumer")
local schema_def = require("apisix.schema_def")
local auth_utils = require("apisix.utils.auth")
-- Cache parsed Authorization headers (5 min TTL) so the same credentials
-- are not base64-decoded and split on every request.
local lrucache = core.lrucache.new({
    ttl = 300, count = 512
})

-- Plugin schema when attached to a route or service.
local schema = {
    type = "object",
    title = "work with route or service object",
    properties = {
        -- Strip the Authorization header before proxying upstream.
        hide_credentials = {
            type = "boolean",
            default = false,
        },
        -- Fix: previously declared OUTSIDE `properties`, where JSON Schema
        -- treats it as an unknown (ignored) keyword, so the field was never
        -- validated even though rewrite() reads conf.anonymous_consumer.
        anonymous_consumer = schema_def.anonymous_consumer_schema,
    },
}

-- Schema for credentials configured on the consumer object.
local consumer_schema = {
    type = "object",
    title = "work with consumer object",
    properties = {
        username = { type = "string" },
        password = { type = "string" },
    },
    -- The framework encrypts these fields at rest.
    encrypt_fields = {"password"},
    required = {"username", "password"},
}

local plugin_name = "basic-auth"

local _M = {
    version = 0.1,
    priority = 2520,
    type = 'auth',
    name = plugin_name,
    schema = schema,
    consumer_schema = consumer_schema
}
-- Validate conf against the consumer-side or route-side schema.
function _M.check_schema(conf, schema_type)
    local target = schema
    if schema_type == core.schema.TYPE_CONSUMER then
        target = consumer_schema
    end

    local ok, err = core.schema.check(target, conf)
    if not ok then
        return false, err
    end
    return true
end
-- Parse a Basic Authorization header into username and password.
-- Parse results are cached in the module LRU cache, keyed by the raw header
-- value. Returns (username, password, err); both strings are empty when
-- parsing failed.
local function extract_auth_header(authorization)
    local function do_extract(auth)
        local obj = { username = "", password = "" }

        -- Capture the base64 payload after the "Basic " scheme.
        local m, err = ngx.re.match(auth, "Basic\\s(.+)", "jo")
        if err then
            -- error authorization
            return nil, err
        end

        if not m then
            return nil, "Invalid authorization header format"
        end

        local decoded = ngx.decode_base64(m[1])

        if not decoded then
            return nil, "Failed to decode authentication header: " .. m[1]
        end

        local res
        res, err = ngx_re.split(decoded, ":")
        if err then
            return nil, "Split authorization err:" .. err
        end
        if #res < 2 then
            return nil, "Split authorization err: invalid decoded data: " .. decoded
        end

        -- Strip ALL whitespace from both parts.
        -- NOTE(review): this also removes whitespace *inside* passwords and
        -- the info-level log below includes the credentials — confirm both
        -- are intended.
        obj.username = ngx.re.gsub(res[1], "\\s+", "", "jo")
        obj.password = ngx.re.gsub(res[2], "\\s+", "", "jo")
        core.log.info("plugin access phase, authorization: ",
                      obj.username, ": ", obj.password)

        return obj, nil
    end

    -- Cache hit yields the parsed object; on parse failure matcher is nil
    -- and err carries the reason.
    local matcher, err = lrucache(authorization, nil, do_extract, authorization)

    if matcher then
        return matcher.username, matcher.password, err
    else
        return "", "", err
    end
end
-- Authenticate the request: extract Basic credentials, look up the consumer
-- by username and verify the password.
-- Returns (consumer, consumer_conf, err); consumer is nil on failure.
local function find_consumer(ctx)
    local auth_header = core.request.header(ctx, "Authorization")

    if not auth_header then
        -- Challenge the client so browsers prompt for credentials.
        core.response.set_header("WWW-Authenticate", "Basic realm='.'")
        return nil, nil, "Missing authorization in request"
    end

    local username, password, err = extract_auth_header(auth_header)
    if err then
        -- Under multi-auth, surface the raw error so the wrapper can try the
        -- next auth plugin; otherwise hide details from the client.
        if auth_utils.is_running_under_multi_auth(ctx) then
            return nil, nil, err
        end
        core.log.warn(err)
        return nil, nil, "Invalid authorization in request"
    end

    local cur_consumer, consumer_conf, err = consumer.find_consumer(plugin_name,
                                                                    "username", username)
    if not cur_consumer then
        err = "failed to find user: " .. (err or "invalid user")
        if auth_utils.is_running_under_multi_auth(ctx) then
            return nil, nil, err
        end
        core.log.warn(err)
        return nil, nil, "Invalid user authorization"
    end

    -- NOTE(review): plain (non-constant-time) string comparison of the
    -- stored password — confirm acceptable for the threat model.
    if cur_consumer.auth_conf.password ~= password then
        return nil, nil, "Invalid user authorization"
    end

    return cur_consumer, consumer_conf, err
end
-- Rewrite-phase handler: authenticate via Basic credentials, optionally
-- falling back to a configured anonymous consumer, then attach the matched
-- consumer to the request context.
function _M.rewrite(conf, ctx)
    core.log.info("plugin access phase, conf: ", core.json.delay_encode(conf))

    local cur_consumer, consumer_conf, err = find_consumer(ctx)
    if not cur_consumer then
        if not conf.anonymous_consumer then
            return 401, { message = err }
        end
        cur_consumer, consumer_conf, err = consumer.get_anonymous_consumer(conf.anonymous_consumer)
        if not cur_consumer then
            err = "basic-auth failed to authenticate the request, code: 401. error: " .. err
            core.log.error(err)
            return 401, { message = "Invalid user authorization" }
        end
    end

    core.log.info("consumer: ", core.json.delay_encode(cur_consumer))

    -- Optionally strip the Authorization header before proxying upstream.
    if conf.hide_credentials then
        core.request.set_header(ctx, "Authorization", nil)
    end

    consumer.attach_consumer(ctx, cur_consumer, consumer_conf)
    core.log.info("hit basic-auth access")
end

return _M

View File

@@ -0,0 +1,309 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local http = require("resty.http")
local plugin = require("apisix.plugin")
local ngx = ngx
local ipairs = ipairs
local pairs = pairs
local str_find = core.string.find
local str_lower = string.lower
local plugin_name = "batch-requests"
local default_uri = "/apisix/batch-requests"
-- Plugin attributes (instance-local config): allows overriding the URI the
-- batch endpoint is served on.
local attr_schema = {
    type = "object",
    properties = {
        uri = {
            type = "string",
            description = "uri for batch-requests",
            default = default_uri
        }
    },
}

-- Route-level schema: the plugin takes no per-route options.
local schema = {
    type = "object",
}

local default_max_body_size = 1024 * 1024 -- 1MiB

-- Cluster-wide metadata: caps the size of the batch request body.
local metadata_schema = {
    type = "object",
    properties = {
        max_body_size = {
            description = "max pipeline body size in bytes",
            type = "integer",
            exclusiveMinimum = 0,
            default = default_max_body_size,
        },
    },
}
-- Per-pipeline-entry HTTP method; defaults to GET.
local method_schema = core.table.clone(core.schema.method_schema)
method_schema.default = "GET"

-- Schema of the batch request body: top-level `query`/`headers` are merged
-- into every pipelined request; each pipeline entry may override them.
local req_schema = {
    type = "object",
    properties = {
        query = {
            description = "pipeline query string",
            type = "object"
        },
        headers = {
            description = "pipeline header",
            type = "object"
        },
        timeout = {
            description = "pipeline timeout(ms)",
            type = "integer",
            default = 30000,
        },
        pipeline = {
            type = "array",
            minItems = 1,
            items = {
                type = "object",
                properties = {
                    version = {
                        description = "HTTP version",
                        type = "number",
                        enum = {1.0, 1.1},
                        default = 1.1,
                    },
                    method = method_schema,
                    path = {
                        type = "string",
                        minLength = 1,
                    },
                    -- Fix: the descriptions of `query` and `headers` were
                    -- swapped ("request header" sat on query and vice versa).
                    query = {
                        description = "request query string",
                        type = "object",
                    },
                    headers = {
                        description = "request header",
                        type = "object",
                    },
                    ssl_verify = {
                        type = "boolean",
                        default = false
                    },
                }
            }
        }
    },
    anyOf = {
        {required = {"pipeline"}},
    },
}
-- Plugin descriptor; `scope = "global"` because the batch endpoint is
-- registered once per instance via _M.api(), not per route.
local _M = {
    version = 0.1,
    priority = 4010,
    name = plugin_name,
    schema = schema,
    metadata_schema = metadata_schema,
    attr_schema = attr_schema,
    scope = "global",
}
-- Validate conf against the metadata or route-level schema.
function _M.check_schema(conf, schema_type)
    local target = schema
    if schema_type == core.schema.TYPE_METADATA then
        target = metadata_schema
    end
    return core.schema.check(target, conf)
end
-- Validate the decoded batch body; returns (code, body) on failure and
-- nothing on success.
local function check_input(data)
    local ok, err = core.schema.check(req_schema, data)
    if ok then
        return
    end
    return 400, {error_msg = "bad request body: " .. err}
end
-- Return a copy of obj with all keys lowercased; nil yields an empty table.
local function lowercase_key_or_init(obj)
    local normalized = {}
    if obj then
        for key, val in pairs(obj) do
            normalized[str_lower(key)] = val
        end
    end
    return normalized
end
-- Normalize the shared headers and every pipelined request's headers to
-- lowercase keys, so later merge lookups are case-insensitive.
local function ensure_header_lowercase(data)
    data.headers = lowercase_key_or_init(data.headers)
    for _, pipelined_req in ipairs(data.pipeline) do
        pipelined_req.headers = lowercase_key_or_init(pipelined_req.headers)
    end
end
-- Propagate the shared headers and the outer request's headers onto each
-- pipelined request, then stamp the real client IP. Per-request headers
-- always win; outer headers starting with "content-" are never copied.
local function set_common_header(data)
    local local_conf = core.config.local_conf()
    -- NOTE(review): assumes nginx_config.http.real_ip_header is always
    -- present in the local config (str_lower(nil) would error) — confirm a
    -- default is guaranteed.
    local real_ip_hdr = core.table.try_read_attr(local_conf, "nginx_config", "http",
                                                 "real_ip_header")
    -- we don't need to handle '_' to '-' as Nginx won't treat 'X_REAL_IP' as 'X-Real-IP'
    real_ip_hdr = str_lower(real_ip_hdr)

    local outer_headers = core.request.headers(nil)
    for i,req in ipairs(data.pipeline) do
        for k, v in pairs(data.headers) do
            if not req.headers[k] then
                req.headers[k] = v
            end
        end

        if outer_headers then
            for k, v in pairs(outer_headers) do
                local is_content_header = str_find(k, "content-") == 1
                -- skip header start with "content-"
                if not req.headers[k] and not is_content_header then
                    req.headers[k] = v
                end
            end
        end

        -- Each internal request carries the original client address.
        req.headers[real_ip_hdr] = core.request.get_remote_client_ip()
    end
end
-- Merge the shared top-level query parameters into each pipelined request,
-- never overwriting per-request values. A request without its own query
-- table shares the common one by reference.
local function set_common_query(data)
    local shared = data.query
    if not shared then
        return
    end

    for _, req in ipairs(data.pipeline) do
        local own = req.query
        if own then
            for k, v in pairs(shared) do
                if not own[k] then
                    own[k] = v
                end
            end
        else
            req.query = shared
        end
    end
end
-- Handler for the batch endpoint: decode the pipeline spec from the request
-- body, replay each entry against this APISIX instance over one pipelined
-- connection, and return the aggregated responses as a JSON array.
local function batch_requests(ctx)
    local metadata = plugin.plugin_metadata(plugin_name)
    core.log.info("metadata: ", core.json.delay_encode(metadata))

    -- Body size limit comes from plugin metadata, falling back to 1MiB.
    local max_body_size
    if metadata then
        max_body_size = metadata.value.max_body_size
    else
        max_body_size = default_max_body_size
    end

    local req_body, err = core.request.get_body(max_body_size, ctx)
    if err then
        -- Nginx doesn't support 417: https://trac.nginx.org/nginx/ticket/2062
        -- So always return 413 instead
        return 413, { error_msg = err }
    end

    if not req_body then
        return 400, {
            error_msg = "no request body, you should give at least one pipeline setting"
        }
    end

    local data, err = core.json.decode(req_body)
    if not data then
        return 400, {
            error_msg = "invalid request body: " .. req_body .. ", err: " .. err
        }
    end

    local code, body = check_input(data)
    if code then
        return code, body
    end

    -- Loop the pipelined requests back into this server.
    local httpc = http.new()
    httpc:set_timeout(data.timeout)
    local ok, err = httpc:connect("127.0.0.1", ngx.var.server_port)
    if not ok then
        return 500, {error_msg = "connect to apisix failed: " .. err}
    end

    ensure_header_lowercase(data)
    set_common_header(data)
    set_common_query(data)

    local responses, err = httpc:request_pipeline(data.pipeline)
    if not responses then
        return 400, {error_msg = "request failed: " .. err}
    end

    local aggregated_resp = {}
    for _, resp in ipairs(responses) do
        if not resp.status then
            -- The pipelined read failed/timed out for this entry. Record a
            -- 504 placeholder and skip body handling.
            -- Fix: previously an empty sub-response table was ALSO appended
            -- after the placeholder, producing two entries for a single
            -- timed-out request.
            core.table.insert(aggregated_resp, {
                status = 504,
                reason = "upstream timeout"
            })
        else
            local sub_resp = {
                status  = resp.status,
                reason  = resp.reason,
                headers = resp.headers,
            }
            if resp.has_body then
                local err
                sub_resp.body, err = resp:read_body()
                if err then
                    sub_resp.read_body_err = err
                    core.log.error("read pipeline response body failed: ", err)
                else
                    resp:read_trailers()
                end
            end
            core.table.insert(aggregated_resp, sub_resp)
        end
    end
    return 200, aggregated_resp
end
-- Register the batch endpoint. The URI can be overridden via plugin
-- attributes; otherwise the default is used.
function _M.api()
    local attr = plugin.plugin_attr(plugin_name)
    local uri = (attr and attr.uri) or default_uri
    return {
        {
            methods = {"POST"},
            uri = uri,
            handler = batch_requests,
        }
    }
end

return _M

View File

@@ -0,0 +1,261 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local xml2lua = require("xml2lua")
local xmlhandler = require("xmlhandler.tree")
local template = require("resty.template")
local ngx = ngx
local decode_base64 = ngx.decode_base64
local req_set_body_data = ngx.req.set_body_data
local req_get_uri_args = ngx.req.get_uri_args
local str_format = string.format
local decode_args = ngx.decode_args
local str_find = core.string.find
local type = type
local pcall = pcall
local pairs = pairs
local next = next
local multipart = require("multipart")
local setmetatable = setmetatable
-- Shared schema for the request- and response-side transformation specs.
local transform_schema = {
    type = "object",
    properties = {
        -- How to decode the body before rendering the template; when
        -- omitted it is auto-detected from Content-Type, "plain" disables
        -- decoding entirely.
        input_format = { type = "string",
                         enum = {"xml", "json", "encoded", "args", "plain", "multipart",}},
        -- resty.template source, optionally base64-encoded.
        template = { type = "string" },
        template_is_base64 = { type = "boolean" },
    },
    required = {"template"},
}

-- At least one of request/response must be configured.
local schema = {
    type = "object",
    properties = {
        request = transform_schema,
        response = transform_schema,
    },
    anyOf = {
        {required = {"request"}},
        {required = {"response"}},
        {required = {"request", "response"}},
    },
}

local _M = {
    version = 0.1,
    priority = 1080,
    name = "body-transformer",
    schema = schema,
}
-- Escape the five XML special characters in s. '&' must be replaced first
-- so later substitutions' ampersands are not double-escaped.
-- Fix: each gsub returns (string, count); the unparenthesized chain leaked
-- the final substitution count as a second return value, which could corrupt
-- template output when the call sits at the end of an expression list.
-- The parentheses truncate the return to exactly one value.
local function escape_xml(s)
    return (s:gsub("&", "&amp;")
             :gsub("<", "&lt;")
             :gsub(">", "&gt;")
             :gsub("'", "&apos;")
             :gsub('"', "&quot;"))
end
-- Template helper: JSON-encode a value so it can be embedded safely in
-- JSON output (strings become quoted/escaped JSON strings).
local function escape_json(s)
    return core.json.encode(s)
end
-- Recursively strip XML namespace prefixes from table keys produced by the
-- XML decoder ("ns:tag" -> "tag") and normalize empty element tables to "".
-- Mutates tbl in place and returns it.
-- Fix: the original inserted the renamed key into tbl while iterating it
-- with pairs(); adding NEW keys during a pairs() traversal is undefined
-- behavior in Lua. Renames are now collected first and applied afterwards
-- (mutating existing keys during traversal is allowed).
local function remove_namespace(tbl)
    -- Normalize empty child tables to "" (existing-key mutation: safe).
    for k, v in pairs(tbl) do
        if type(v) == "table" and next(v) == nil then
            tbl[k] = ""
        end
    end

    local renames = {}
    for k, v in pairs(tbl) do
        if type(k) == "string" then
            -- Greedy match keeps only the segment after the last ':'.
            local newk = k:match(".*:(.*)")
            if newk then
                renames[k] = newk
            end
            if type(v) == "table" then
                remove_namespace(v)
            end
        end
    end

    for k, newk in pairs(renames) do
        tbl[newk] = tbl[k]
        tbl[k] = nil
    end

    return tbl
end
-- input_format -> decoder. Each decoder turns the raw body into a table
-- used as the template environment, returning (value) or (nil, err).
local decoders = {
    xml = function(data)
        local handler = xmlhandler:new()
        local parser = xml2lua.parser(handler)
        -- xml2lua raises on malformed input, hence the pcall.
        local ok, err = pcall(parser.parse, parser, data)
        if ok then
            return remove_namespace(handler.root)
        else
            return nil, err
        end
    end,
    json = function(data)
        return core.json.decode(data)
    end,
    encoded = function(data)
        -- application/x-www-form-urlencoded body
        return decode_args(data)
    end,
    args = function()
        -- GET requests: render from the URI arguments instead of a body.
        return req_get_uri_args()
    end,
    multipart = function (data, content_type_header)
        -- Content-Type is needed for the multipart boundary.
        local res = multipart(data, content_type_header)
        return res
    end
}
-- Validate the plugin configuration.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end
-- Decode the body according to conf[typ].input_format, then render the
-- configured template with the decoded table as its environment.
-- typ is "request" or "response". Returns the rendered body on success,
-- or (nil, status, err) on failure.
local function transform(conf, body, typ, ctx, request_method)
    local out = {}
    local _multipart
    local format = conf[typ].input_format
    local ct = ctx.var.http_content_type
    if typ == "response" then
        ct = ngx.header.content_type
    end
    -- GET requests have no body but can still render from URI args.
    if (body or request_method == "GET") and format ~= "plain" then
        local err
        if format then
            out, err = decoders[format](body, ct)
            if format == "multipart" then
                -- Keep the multipart object reachable from the template via
                -- _multipart; render against its plain-table form.
                _multipart = out
                out = out:get_all_with_arrays()
            end
            if not out then
                err = str_format("%s body decode: %s", typ, err)
                core.log.error(err, ", body=", body)
                return nil, 400, err
            end
        else
            core.log.warn("no input format to parse ", typ, " body")
        end
    end

    local text = conf[typ].template
    -- Templates may be base64-encoded: either flagged explicitly, or
    -- implicitly for formats other than "encoded"/"args".
    if (conf[typ].template_is_base64 or (format and format ~= "encoded" and format ~= "args")) then
        text = decode_base64(text) or text
    end

    local ok, render = pcall(template.compile, text)
    if not ok then
        local err = render
        err = str_format("%s template compile: %s", typ, err)
        core.log.error(err)
        return nil, 503, err
    end

    -- Expose helpers and the request context to the template without
    -- copying them into the decoded table itself.
    setmetatable(out, {__index = {
        _ctx = ctx,
        _body = body,
        _escape_xml = escape_xml,
        _escape_json = escape_json,
        _multipart = _multipart
    }})

    local ok, render_out = pcall(render, out)
    if not ok then
        local err = str_format("%s template rendering: %s", typ, render_out)
        core.log.error(err)
        return nil, 503, err
    end

    core.log.info(typ, " body transform output=", render_out)
    return render_out
end
-- Derive conf[typ].input_format from the request method / Content-Type when
-- not configured explicitly. Mutates conf[typ]; callers are expected to
-- pass a per-request copy of conf.
-- NOTE(review): only "text/xml" is detected for XML ("application/xml" is
-- not) — confirm that is intended.
local function set_input_format(conf, typ, ct, method)
    -- GET has no body; always render from URI args.
    if method == "GET" then
        conf[typ].input_format = "args"
    end
    if conf[typ].input_format == nil and ct then
        if ct:find("text/xml") then
            conf[typ].input_format = "xml"
        elseif ct:find("application/json") then
            conf[typ].input_format = "json"
        elseif str_find(ct:lower(), "application/x-www-form-urlencoded", nil, true) then
            conf[typ].input_format = "encoded"
        elseif str_find(ct:lower(), "multipart/", nil, true) then
            conf[typ].input_format = "multipart"
        end
    end
end
-- Rewrite phase: transform the request body before it is proxied upstream.
function _M.rewrite(conf, ctx)
    if conf.request then
        local request_method = ngx.var.request_method
        -- Work on a per-request copy: set_input_format mutates the conf.
        conf = core.table.deepcopy(conf)
        ctx.body_transformer_conf = conf
        local body = core.request.get_body()
        set_input_format(conf, "request", ctx.var.http_content_type, request_method)
        local out, status, err = transform(conf, body, "request", ctx, request_method)
        if not out then
            return status, { message = err }
        end
        req_set_body_data(out)
    end
end
-- Header filter: detect the response input format and clear length headers
-- since the body is about to be rewritten.
-- Fix: when rewrite() had already stored a per-request deepcopy, this phase
-- kept operating on the SHARED plugin conf — set_input_format then mutated
-- cross-request state, and the detected response input_format never reached
-- the copy that body_filter reads via ctx.body_transformer_conf.
function _M.header_filter(conf, ctx)
    if not conf.response then
        return
    end
    if not ctx.body_transformer_conf then
        -- First phase to run for this request: take a private copy.
        conf = core.table.deepcopy(conf)
        ctx.body_transformer_conf = conf
    else
        -- Reuse the per-request copy created by rewrite().
        conf = ctx.body_transformer_conf
    end
    set_input_format(conf, "response", ngx.header.content_type)
    core.response.clear_header_as_body_modified()
end
-- Body filter: buffer the full response body, then transform it in one shot.
function _M.body_filter(_, ctx)
    -- Use the per-request conf stored by rewrite()/header_filter().
    local conf = ctx.body_transformer_conf
    if not conf then
        return
    end
    if conf.response then
        -- hold_body_chunk accumulates chunks and only returns the complete
        -- body once the last chunk has been seen.
        local body = core.response.hold_body_chunk(ctx)
        if ngx.arg[2] == false and not body then
            return
        end
        local out = transform(conf, body, "response", ctx)
        if not out then
            -- On failure the original (buffered) body is emitted unchanged.
            core.log.error("failed to transform response body: ", body)
            return
        end
        ngx.arg[1] = out
    end
end

return _M

View File

@@ -0,0 +1,248 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local ngx = ngx
local ngx_re_gmatch = ngx.re.gmatch
local ngx_header = ngx.header
local req_http_version = ngx.req.http_version
local str_sub = string.sub
local ipairs = ipairs
local tonumber = tonumber
local type = type
local is_loaded, brotli = pcall(require, "brotli")
-- Configuration schema for the brotli compression plugin.
local schema = {
    type = "object",
    properties = {
        -- Response content types eligible for compression; "*" matches all.
        types = {
            anyOf = {
                {
                    type = "array",
                    minItems = 1,
                    items = {
                        type = "string",
                        minLength = 1,
                    },
                },
                {
                    enum = {"*"}
                }
            },
            default = {"text/html"}
        },
        -- Minimum response length (bytes) before compression kicks in.
        min_length = {
            type = "integer",
            minimum = 1,
            default = 20,
        },
        mode = {
            type = "integer",
            minimum = 0,
            maximum = 2,
            default = 0,
            -- 0: MODE_GENERIC (default),
            -- 1: MODE_TEXT (for UTF-8 format text input)
            -- 2: MODE_FONT (for WOFF 2.0)
        },
        comp_level = {
            type = "integer",
            minimum = 0,
            maximum = 11,
            default = 6,
            -- follow the default value from ngx_brotli brotli_comp_level
        },
        lgwin = {
            type = "integer",
            default = 19,
            -- follow the default value from ngx_brotli brotli_window
            enum = {0,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24},
        },
        lgblock = {
            type = "integer",
            default = 0,
            enum = {0,16,17,18,19,20,21,22,23,24},
        },
        -- Minimum HTTP version of the client request to allow compression.
        http_version = {
            enum = {1.1, 1.0},
            default = 1.1,
        },
        -- Whether to emit a "Vary" header alongside the encoded response.
        vary = {
            type = "boolean",
        }
    },
}

local _M = {
    version = 0.1,
    priority = 996,
    name = "brotli",
    schema = schema,
}
-- Validate the user-supplied plugin configuration against `schema`.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end
-- Build a fresh streaming brotli compressor from the plugin settings.
-- One compressor is created per matched response (see header_filter).
local function create_brotli_compressor(mode, comp_level, lgwin, lgblock)
    return brotli.compressor:new({
        mode = mode,
        quality = comp_level,
        lgwin = lgwin,
        lgblock = lgblock,
    })
end
-- Decide whether the client accepts brotli ("br") compression based on
-- its Accept-Encoding header.
-- Returns true when "br" (or "*") is acceptable with a non-zero q-value.
local function check_accept_encoding(ctx)
    local accept_encoding = core.request.header(ctx, "Accept-Encoding")
    -- no Accept-Encoding
    if not accept_encoding then
        return false
    end

    -- single Accept-Encoding
    if accept_encoding == "*" or accept_encoding == "br" then
        return true
    end

    -- multi Accept-Encoding, e.g. "gzip;q=1.0, br;q=0.8, *;q=0.1"
    local iterator, err = ngx_re_gmatch(accept_encoding,
                                        [[([a-z\*]+)(;q=)?([0-9.]*)?]], "jo")
    if not iterator then
        core.log.error("gmatch failed, error: ", err)
        return false
    end

    local captures
    while true do
        captures, err = iterator()
        if not captures then
            break
        end
        if err then
            core.log.error("iterator failed, error: ", err)
            return false
        end
        -- a q-value of zero means "not acceptable"; compare the parsed
        -- number instead of the raw string so "q=0.0", "q=0.00" etc. are
        -- also treated as a refusal (the old string compare only caught
        -- the literal "0")
        if (captures[1] == "br" or captures[1] == "*") and
           (not captures[2] or tonumber(captures[3]) ~= 0) then
            return true
        end
    end

    return false
end
-- header_filter phase: decide whether this response should be
-- brotli-compressed; if so, create a compressor in ctx and adjust the
-- response headers accordingly.
function _M.header_filter(conf, ctx)
    -- plugin is a no-op when the lua brotli binding failed to load
    if not is_loaded then
        core.log.error("please check the brotli library")
        return
    end
    -- client must advertise "br" (or "*") support with non-zero q-value
    local allow_encoding = check_accept_encoding(ctx)
    if not allow_encoding then
        return
    end
    local content_encoded = ngx_header["Content-Encoding"]
    if content_encoded then
        -- Don't compress if Content-Encoding is present in upstream data
        return
    end
    -- types == "*" compresses everything; otherwise require an exact
    -- media-type match (";" parameters such as charset are stripped)
    local types = conf.types
    local content_type = ngx_header["Content-Type"]
    if not content_type then
        -- Like Nginx, don't compress if Content-Type is missing
        return
    end
    if type(types) == "table" then
        local matched = false
        local from = core.string.find(content_type, ";")
        if from then
            content_type = str_sub(content_type, 1, from - 1)
        end
        for _, ty in ipairs(types) do
            if content_type == ty then
                matched = true
                break
            end
        end
        if not matched then
            return
        end
    end
    local content_length = tonumber(ngx_header["Content-Length"])
    if content_length then
        local min_length = conf.min_length
        if content_length < min_length then
            return
        end
        -- Like Nginx, don't check min_length if Content-Length is missing
    end
    -- skip clients speaking an older HTTP version than configured
    local http_version = req_http_version()
    if http_version < conf.http_version then
        return
    end
    if conf.vary then
        core.response.add_header("Vary", "Accept-Encoding")
    end
    local compressor = create_brotli_compressor(conf.mode, conf.comp_level,
                                                conf.lgwin, conf.lgblock)
    if not compressor then
        core.log.error("failed to create brotli compressor")
        return
    end
    -- hand the compressor to body_filter and mark the body as modified
    ctx.brotli_matched = true
    ctx.compressor = compressor
    core.response.clear_header_as_body_modified()
    core.response.add_header("Content-Encoding", "br")
end
-- body_filter phase: stream each response chunk through the per-request
-- brotli compressor created in header_filter.
function _M.body_filter(conf, ctx)
    -- header_filter decided not to compress this response
    if not ctx.brotli_matched then
        return
    end
    local chunk, eof = ngx.arg[1], ngx.arg[2]
    if type(chunk) == "string" and chunk ~= "" then
        -- flush after every chunk so output is produced incrementally
        local encode_chunk = ctx.compressor:compress(chunk)
        ngx.arg[1] = encode_chunk .. ctx.compressor:flush()
    end
    if eof then
        -- overwriting the arg[1], results into partial response
        ngx.arg[1] = ngx.arg[1] .. ctx.compressor:finish()
    end
end
return _M

View File

@@ -0,0 +1,201 @@
--
---- Licensed to the Apache Software Foundation (ASF) under one or more
---- contributor license agreements. See the NOTICE file distributed with
---- this work for additional information regarding copyright ownership.
---- The ASF licenses this file to You under the Apache License, Version 2.0
---- (the "License"); you may not use this file except in compliance with
---- the License. You may obtain a copy of the License at
----
---- http://www.apache.org/licenses/LICENSE-2.0
----
---- Unless required by applicable law or agreed to in writing, software
---- distributed under the License is distributed on an "AS IS" BASIS,
---- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
---- See the License for the specific language governing permissions and
---- limitations under the License.
----
local core = require("apisix.core")
local http = require("resty.http")
local ngx = ngx
local ngx_re_match = ngx.re.match
local CAS_REQUEST_URI = "CAS_REQUEST_URI"
local COOKIE_NAME = "CAS_SESSION"
local COOKIE_PARAMS = "; Path=/; HttpOnly"
local SESSION_LIFETIME = 3600
local STORE_NAME = "cas_sessions"
local store = ngx.shared[STORE_NAME]
local plugin_name = "cas-auth"
local schema = {
type = "object",
properties = {
idp_uri = {type = "string"},
cas_callback_uri = {type = "string"},
logout_uri = {type = "string"},
},
required = {
"idp_uri", "cas_callback_uri", "logout_uri"
}
}
local _M = {
version = 0.1,
priority = 2597,
name = plugin_name,
schema = schema
}
-- Validate the plugin configuration; also runs the shared https check
-- over idp_uri.
function _M.check_schema(conf)
    core.utils.check_https({"idp_uri"}, conf, plugin_name)
    return core.schema.check(schema, conf)
end
-- Rebuild the callback URL (scheme://host:port/callback) that is sent to
-- the CAS server as the "service" parameter.
local function uri_without_ticket(conf, ctx)
    local v = ctx.var
    return v.scheme .. "://" .. v.host .. ":" ..
           v.server_port .. conf.cas_callback_uri
end
-- Fetch the value of our session cookie (nil when the client sent none).
local function get_session_id(ctx)
    local cookie_var = "cookie_" .. COOKIE_NAME
    return ctx.var[cookie_var]
end
-- Emit a Set-Cookie header with our fixed Path/HttpOnly attributes.
local function set_our_cookie(name, val)
    local cookie = name .. "=" .. val .. COOKIE_PARAMS
    core.response.add_header("Set-Cookie", cookie)
end
-- No valid session: remember the originally requested URI in a cookie
-- and redirect the client to the CAS login page.
local function first_access(conf, ctx)
    local service = uri_without_ticket(conf, ctx)
    local login_uri = conf.idp_uri .. "/login?" ..
        ngx.encode_args({ service = service })
    core.log.info("first access: ", login_uri,
        ", cookie: ", ctx.var.http_cookie, ", request_uri: ", ctx.var.request_uri)
    set_our_cookie(CAS_REQUEST_URI, ctx.var.request_uri)
    core.response.set_header("Location", login_uri)
    return ngx.HTTP_MOVED_TEMPORARILY
end
-- A session cookie is present: let the request pass when the session is
-- known (refreshing its TTL), otherwise drop the stale cookie and
-- restart the login flow.
local function with_session_id(conf, ctx, session_id)
    local user = store:get(session_id)
    core.log.info("ticket=", session_id, ", user=", user)
    if user ~= nil then
        -- sliding expiration: each request extends the session lifetime
        store:set(session_id, user, SESSION_LIFETIME)
        return
    end
    set_our_cookie(COOKIE_NAME, "deleted; Max-Age=0")
    return first_access(conf, ctx)
end
-- Record the validated ticket -> user mapping in the shared dict and
-- hand the ticket to the client as its session cookie.
-- Returns true on success, false otherwise.
local function set_store_and_cookie(session_id, user)
    local success, err, forcible = store:add(session_id, user, SESSION_LIFETIME)
    if not success then
        if err == "no memory" then
            core.log.emerg("CAS cookie store is out of memory")
        elseif err == "exists" then
            core.log.error("Same CAS ticket validated twice, this should never happen!")
        else
            core.log.error("CAS cookie store: ", err)
        end
        return success
    end

    if forcible then
        -- an older entry was evicted to make room for this session
        core.log.info("CAS cookie store is out of memory")
    end
    set_our_cookie(COOKIE_NAME, session_id)
    return success
end
-- Ask the CAS server whether `ticket` is genuine for our service URL.
-- Returns the authenticated user name on success, nil otherwise.
local function validate(conf, ctx, ticket)
    -- send a request to CAS to validate the ticket
    local httpc = http.new()
    local res, err = httpc:request_uri(conf.idp_uri ..
        "/serviceValidate",
        { query = { ticket = ticket, service = uri_without_ticket(conf, ctx) } })
    if res and res.status == ngx.HTTP_OK and res.body ~= nil then
        if core.string.find(res.body, "<cas:authenticationSuccess>") then
            -- pull the user name out of the CAS XML response
            local m = ngx_re_match(res.body, "<cas:user>(.*?)</cas:user>", "jo");
            if m then
                return m[1]
            end
        else
            -- the IdP answered but rejected the ticket
            core.log.info("CAS serviceValidate failed: ", res.body)
        end
    else
        core.log.error("validate ticket failed: status=", (res and res.status),
            ", has_body=", (res and res.body ~= nil or false), ", err=", err)
    end
    return nil
end
-- Exchange a fresh CAS ticket for a local session: validate it against
-- the IdP, persist the session, then redirect the client back to the URI
-- it originally requested.
local function validate_with_cas(conf, ctx, ticket)
    local user = validate(conf, ctx, ticket)
    if user and set_store_and_cookie(ticket, user) then
        -- URI saved by first_access() in a cookie.
        -- NOTE(review): this value is client-controlled cookie data used
        -- as a redirect target -- confirm it is constrained upstream
        local request_uri = ctx.var["cookie_" .. CAS_REQUEST_URI]
        set_our_cookie(CAS_REQUEST_URI, "deleted; Max-Age=0")
        core.log.info("ticket: ", ticket,
            ", cookie: ", ctx.var.http_cookie, ", request_uri: ", request_uri, ", user=", user)
        core.response.set_header("Location", request_uri)
        return ngx.HTTP_MOVED_TEMPORARILY
    else
        return ngx.HTTP_UNAUTHORIZED, {message = "invalid ticket"}
    end
end
-- Destroy the local session and forward the client to the IdP's logout
-- endpoint.
local function logout(conf, ctx)
    local session_id = get_session_id(ctx)
    if session_id == nil then
        return ngx.HTTP_UNAUTHORIZED
    end

    core.log.info("logout: ticket=", session_id, ", cookie=", ctx.var.http_cookie)
    store:delete(session_id)
    set_our_cookie(COOKIE_NAME, "deleted; Max-Age=0")
    core.response.set_header("Location", conf.idp_uri .. "/logout")
    return ngx.HTTP_MOVED_TEMPORARILY
end
-- access phase: routes requests between logout, IdP back-channel Single
-- Logout (SLO), ticket validation, session handling and first access.
function _M.access(conf, ctx)
    local method = core.request.get_method()
    local uri = ctx.var.uri

    if method == "GET" and uri == conf.logout_uri then
        return logout(conf, ctx)
    end

    if method == "POST" and uri == conf.cas_callback_uri then
        -- back-channel logout (SLO) request from the IdP
        local data = core.request.get_body()
        -- get_body() may return nil (e.g. no request body); guard it
        -- before matching, otherwise we hit a nil-index error
        local ticket = data and data:match("<samlp:SessionIndex>(.*)</samlp:SessionIndex>")
        if ticket == nil then
            return ngx.HTTP_BAD_REQUEST,
                {message = "invalid logout request from IdP, no ticket"}
        end
        core.log.info("Back-channel logout (SLO) from IdP: LogoutRequest: ", data)
        local session_id = ticket
        local user = store:get(session_id);
        if user then
            store:delete(session_id)
            -- (fixed log typo: "tocket" -> "ticket")
            core.log.info("SLO: user=", user, ", ticket=", ticket)
        end
    else
        local session_id = get_session_id(ctx)
        if session_id ~= nil then
            return with_session_id(conf, ctx, session_id)
        end

        -- a ticket on the callback URI means we just came back from the
        -- CAS login page; anything else starts the login flow
        local ticket = ctx.var.arg_ticket
        if ticket ~= nil and uri == conf.cas_callback_uri then
            return validate_with_cas(conf, ctx, ticket)
        else
            return first_access(conf, ctx)
        end
    end
end
return _M

View File

@@ -0,0 +1,421 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local require = require
local core = require("apisix.core")
local rr_balancer = require("apisix.balancer.roundrobin")
local plugin = require("apisix.plugin")
local t1k = require "resty.t1k"
local expr = require("resty.expr.v1")
local ngx = ngx
local ngx_now = ngx.now
local string = string
local fmt = string.format
local tostring = tostring
local tonumber = tonumber
local ipairs = ipairs
local plugin_name = "chaitin-waf"
local vars_schema = {
type = "array",
}
local lrucache = core.lrucache.new({
ttl = 300, count = 1024
})
local match_schema = {
type = "array",
items = {
type = "object",
properties = {
vars = vars_schema
}
},
}
local plugin_schema = {
type = "object",
properties = {
mode = {
type = "string",
enum = { "off", "monitor", "block", nil },
default = nil,
},
match = match_schema,
append_waf_resp_header = {
type = "boolean",
default = true
},
append_waf_debug_header = {
type = "boolean",
default = false
},
config = {
type = "object",
properties = {
connect_timeout = {
type = "integer",
},
send_timeout = {
type = "integer",
},
read_timeout = {
type = "integer",
},
req_body_size = {
type = "integer",
},
keepalive_size = {
type = "integer",
},
keepalive_timeout = {
type = "integer",
},
real_client_ip = {
type = "boolean"
}
},
},
},
}
local metadata_schema = {
type = "object",
properties = {
mode = {
type = "string",
enum = { "off", "monitor", "block", nil },
default = nil,
},
nodes = {
type = "array",
items = {
type = "object",
properties = {
host = {
type = "string",
pattern = "^\\*?[0-9a-zA-Z-._\\[\\]:/]+$"
},
port = {
type = "integer",
minimum = 1,
default = 80
},
},
required = { "host" }
},
minItems = 1,
},
config = {
type = "object",
properties = {
connect_timeout = {
type = "integer",
default = 1000 -- milliseconds
},
send_timeout = {
type = "integer",
default = 1000 -- milliseconds
},
read_timeout = {
type = "integer",
default = 1000 -- milliseconds
},
req_body_size = {
type = "integer",
default = 1024 -- milliseconds
},
-- maximum concurrent idle connections to
-- the SafeLine WAF detection service
keepalive_size = {
type = "integer",
default = 256
},
keepalive_timeout = {
type = "integer",
default = 60000 -- milliseconds
},
real_client_ip = {
type = "boolean",
default = true
}
},
default = {},
},
},
required = { "nodes" },
}
local _M = {
version = 0.1,
priority = 2700,
name = plugin_name,
schema = plugin_schema,
metadata_schema = metadata_schema
}
local global_server_picker
local HEADER_CHAITIN_WAF = "X-APISIX-CHAITIN-WAF"
local HEADER_CHAITIN_WAF_ERROR = "X-APISIX-CHAITIN-WAF-ERROR"
local HEADER_CHAITIN_WAF_TIME = "X-APISIX-CHAITIN-WAF-TIME"
local HEADER_CHAITIN_WAF_STATUS = "X-APISIX-CHAITIN-WAF-STATUS"
local HEADER_CHAITIN_WAF_ACTION = "X-APISIX-CHAITIN-WAF-ACTION"
local HEADER_CHAITIN_WAF_SERVER = "X-APISIX-CHAITIN-WAF-SERVER"
local blocked_message = [[{"code": %s, "success":false, ]] ..
[["message": "blocked by Chaitin SafeLine Web Application Firewall", "event_id": "%s"}]]
local warning_message = "chaitin-waf monitor mode: request would have been rejected, event_id: "
-- Validate either the plugin metadata or the per-route configuration,
-- depending on schema_type.
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_METADATA then
        return core.schema.check(metadata_schema, conf)
    end

    local ok, err = core.schema.check(plugin_schema, conf)
    if not ok then
        return false, err
    end

    -- compile every match expression eagerly so that broken configs are
    -- rejected at configuration time rather than on the first request
    if conf.match then
        for _, m in ipairs(conf.match) do
            local exp, exp_err = expr.new(m.vars)
            if not exp then
                return false, "failed to validate the 'vars' expression: " .. exp_err
            end
        end
    end

    return true
end
-- Flatten the configured node list into the "host:port -> weight" map
-- expected by the roundrobin balancer. (`checker` is currently unused.)
local function get_healthy_chaitin_server_nodes(metadata, checker)
    local nodes = metadata.nodes
    local server_map = core.table.new(0, #nodes)

    for _, node in ipairs(nodes) do
        server_map[node.host .. ":" .. tostring(node.port)] = 1
    end

    return server_map
end
-- Pick one detector address (host, port) via a module-wide round-robin
-- picker. The picker is rebuilt whenever the metadata node list changes
-- (compared by table identity, not by content).
local function get_chaitin_server(metadata, ctx)
    if not global_server_picker or global_server_picker.upstream ~= metadata.value.nodes then
        local up_nodes = get_healthy_chaitin_server_nodes(metadata.value)
        if core.table.nkeys(up_nodes) == 0 then
            return nil, nil, "no healthy nodes"
        end
        core.log.info("chaitin-waf nodes: ", core.json.delay_encode(up_nodes))
        global_server_picker = rr_balancer.new(up_nodes, metadata.value.nodes)
    end
    -- NOTE(review): global_server_picker is per-worker shared state,
    -- rebuilt lazily on config change
    local server = global_server_picker.get(ctx)
    local host, port, err = core.utils.parse_addr(server)
    if err then
        return nil, nil, err
    end
    return host, port, nil
end
-- Evaluate the optional `match.vars` expressions against the request.
-- Returns true when any expression matches (or none are configured);
-- false when nothing matches; (false, msg) on an expression build error.
local function check_match(conf, ctx)
    if not conf.match or #conf.match == 0 then
        return true
    end
    for _, match in ipairs(conf.match) do
        -- compiled expressions are cached under the tostring() identity
        -- of the vars table (stable for a given config object)
        local cache_key = tostring(match.vars)
        local exp, err = lrucache(cache_key, nil, function(vars)
            return expr.new(vars)
        end, match.vars)
        if not exp then
            local msg = "failed to create match expression for " ..
                tostring(match.vars) .. ", err: " .. tostring(err)
            return false, msg
        end
        local matched = exp:eval(ctx.var)
        if matched then
            return true
        end
    end
    return false
end
-- Merge the effective detector settings: route-level `conf` overrides
-- plugin `metadata`, which overrides the built-in defaults.
local function get_conf(conf, metadata)
    local t = {
        mode = "block",
        real_client_ip = true,
    }

    -- Copy one layer of settings into t. real_client_ip needs an
    -- explicit nil check: the previous "x or default" idiom silently
    -- turned a user-configured `false` back into the default `true`.
    local function apply(config)
        if not config then
            return
        end
        t.connect_timeout = config.connect_timeout
        t.send_timeout = config.send_timeout
        t.read_timeout = config.read_timeout
        t.req_body_size = config.req_body_size
        t.keepalive_size = config.keepalive_size
        t.keepalive_timeout = config.keepalive_timeout
        if config.real_client_ip ~= nil then
            t.real_client_ip = config.real_client_ip
        end
    end

    apply(metadata.config)
    apply(conf.config)

    t.mode = conf.mode or metadata.mode or t.mode

    return t
end
-- Run the request through the SafeLine (chaitin) WAF detector.
-- Returns (code, body, extra_headers); a nil code means "continue to
-- the upstream". extra_headers always carries the X-APISIX-CHAITIN-WAF-*
-- diagnostic headers for _M.access to attach.
local function do_access(conf, ctx)
    local extra_headers = {}
    -- the detector node list lives in plugin metadata; it is mandatory
    local metadata = plugin.plugin_metadata(plugin_name)
    if not core.table.try_read_attr(metadata, "value", "nodes") then
        extra_headers[HEADER_CHAITIN_WAF] = "err"
        extra_headers[HEADER_CHAITIN_WAF_ERROR] = "missing metadata"
        return 500, nil, extra_headers
    end
    -- pick a detector instance via round robin
    local host, port, err = get_chaitin_server(metadata, ctx)
    if err then
        extra_headers[HEADER_CHAITIN_WAF] = "unhealthy"
        extra_headers[HEADER_CHAITIN_WAF_ERROR] = tostring(err)
        return 500, nil, extra_headers
    end
    core.log.info("picked chaitin-waf server: ", host, ":", port)
    local t = get_conf(conf, metadata.value)
    t.host = host
    t.port = port
    extra_headers[HEADER_CHAITIN_WAF_SERVER] = host
    local mode = t.mode or "block"
    if mode == "off" then
        extra_headers[HEADER_CHAITIN_WAF] = "off"
        return nil, nil, extra_headers
    end
    -- honour the optional `match` expressions: no match -> skip the WAF
    local match, err = check_match(conf, ctx)
    if not match then
        if err then
            extra_headers[HEADER_CHAITIN_WAF] = "err"
            extra_headers[HEADER_CHAITIN_WAF_ERROR] = tostring(err)
            return 500, nil, extra_headers
        else
            extra_headers[HEADER_CHAITIN_WAF] = "no"
            return nil, nil, extra_headers
        end
    end
    if t.real_client_ip then
        -- NOTE(review): X-Forwarded-For is client-controlled; presumably
        -- acceptable because APISIX sits behind a trusted edge -- confirm
        t.client_ip = ctx.var.http_x_forwarded_for or ctx.var.remote_addr
    else
        t.client_ip = ctx.var.remote_addr
    end
    -- measure the detector round trip in milliseconds
    local start_time = ngx_now() * 1000
    local ok, err, result = t1k.do_access(t, false)
    extra_headers[HEADER_CHAITIN_WAF_TIME] = ngx_now() * 1000 - start_time
    if not ok then
        -- communication failure with the detector, not a security verdict
        extra_headers[HEADER_CHAITIN_WAF] = "waf-err"
        local err_msg = tostring(err)
        if core.string.find(err_msg, "timeout") then
            extra_headers[HEADER_CHAITIN_WAF] = "timeout"
        end
        extra_headers[HEADER_CHAITIN_WAF_ERROR] = tostring(err)
        if mode == "monitor" then
            -- monitor mode never blocks; just record the failure
            core.log.warn("chaitin-waf monitor mode: detected waf error - ", err_msg)
            return nil, nil, extra_headers
        end
        return 500, nil, extra_headers
    else
        extra_headers[HEADER_CHAITIN_WAF] = "yes"
        extra_headers[HEADER_CHAITIN_WAF_ACTION] = "pass"
    end
    local code = 200
    extra_headers[HEADER_CHAITIN_WAF_STATUS] = code
    -- a non-200 detector status with an event id is a rejection verdict
    if result and result.status and result.status ~= 200 and result.event_id then
        extra_headers[HEADER_CHAITIN_WAF_STATUS] = result.status
        extra_headers[HEADER_CHAITIN_WAF_ACTION] = "reject"
        if mode == "monitor" then
            -- monitor mode logs the hit but lets the request through
            core.log.warn(warning_message, result.event_id)
            return nil, nil, extra_headers
        end
        core.log.error("request rejected by chaitin-waf, event_id: " .. result.event_id)
        return tonumber(result.status),
                fmt(blocked_message, result.status, result.event_id) .. "\n",
                extra_headers
    end
    return nil, nil, extra_headers
end
-- access phase entry point: run the WAF check, then attach the
-- diagnostic headers according to the plugin switches.
function _M.access(conf, ctx)
    local code, msg, extra_headers = do_access(conf, ctx)
    -- the error/server headers are debug-only; strip them unless enabled
    if not conf.append_waf_debug_header then
        extra_headers[HEADER_CHAITIN_WAF_ERROR] = nil
        extra_headers[HEADER_CHAITIN_WAF_SERVER] = nil
    end
    if conf.append_waf_resp_header then
        core.response.set_header(extra_headers)
    end
    return code, msg
end
-- header_filter phase: let the t1k library adjust response headers as
-- required by the SafeLine detector protocol.
function _M.header_filter(conf, ctx)
    t1k.do_header_filter()
end
return _M

View File

@@ -0,0 +1,208 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local bp_manager_mod = require("apisix.utils.batch-processor-manager")
local log_util = require("apisix.utils.log-util")
local core = require("apisix.core")
local http = require("resty.http")
local url = require("net.url")
local math_random = math.random
local tostring = tostring
local plugin_name = "clickhouse-logger"
local batch_processor_manager = bp_manager_mod.new(plugin_name)
local schema = {
type = "object",
properties = {
-- deprecated, use "endpoint_addrs" instead
endpoint_addr = core.schema.uri_def,
endpoint_addrs = {items = core.schema.uri_def, type = "array", minItems = 1},
user = {type = "string", default = ""},
password = {type = "string", default = ""},
database = {type = "string", default = ""},
logtable = {type = "string", default = ""},
timeout = {type = "integer", minimum = 1, default = 3},
name = {type = "string", default = "clickhouse logger"},
ssl_verify = {type = "boolean", default = true},
log_format = {type = "object"},
include_req_body = {type = "boolean", default = false},
include_req_body_expr = {
type = "array",
minItems = 1,
items = {
type = "array"
}
},
include_resp_body = {type = "boolean", default = false},
include_resp_body_expr = {
type = "array",
minItems = 1,
items = {
type = "array"
}
}
},
oneOf = {
{required = {"endpoint_addr", "user", "password", "database", "logtable"}},
{required = {"endpoint_addrs", "user", "password", "database", "logtable"}}
},
encrypt_fields = {"password"},
}
local metadata_schema = {
type = "object",
properties = {
log_format = {
type = "object"
}
},
}
local _M = {
version = 0.1,
priority = 398,
name = plugin_name,
schema = batch_processor_manager:wrap_schema(schema),
metadata_schema = metadata_schema,
}
-- Validate either the plugin metadata or the per-route configuration.
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_METADATA then
        return core.schema.check(metadata_schema, conf)
    end

    core.utils.check_https({"endpoint_addrs"}, conf, plugin_name)
    core.utils.check_tls_bool({"ssl_verify"}, conf, plugin_name)

    return core.schema.check(schema, conf)
end
-- POST one encoded batch (JSONEachRow lines) to a ClickHouse endpoint.
-- Returns (true) on success or (false, err_msg) on failure.
local function send_http_data(conf, log_message)
    local err_msg
    local res = true

    -- pick the endpoint: the single deprecated endpoint_addr, or a
    -- random member of endpoint_addrs (cheap client-side load balancing)
    local selected_endpoint_addr
    if conf.endpoint_addr then
        selected_endpoint_addr = conf.endpoint_addr
    else
        selected_endpoint_addr = conf.endpoint_addrs[math_random(#conf.endpoint_addrs)]
    end
    local url_decoded = url.parse(selected_endpoint_addr)
    local host = url_decoded.host
    local port = url_decoded.port

    core.log.info("sending a batch logs to ", selected_endpoint_addr)

    if not port then
        if url_decoded.scheme == "https" then
            port = 443
        else
            port = 80
        end
    end

    local httpc = http.new()
    httpc:set_timeout(conf.timeout * 1000)
    local ok, err = httpc:connect(host, port)

    if not ok then
        return false, "failed to connect to host[" .. host .. "] port["
            .. tostring(port) .. "] " .. err
    end

    if url_decoded.scheme == "https" then
        ok, err = httpc:ssl_handshake(true, host, conf.ssl_verify)
        if not ok then
            httpc:close()
            return false, "failed to perform SSL with host[" .. host .. "] "
                .. "port[" .. tostring(port) .. "] " .. err
        end
    end

    local httpc_res, httpc_err = httpc:request({
        method = "POST",
        path = url_decoded.path,
        query = url_decoded.query,
        body = "INSERT INTO " .. conf.logtable .." FORMAT JSONEachRow " .. log_message,
        headers = {
            ["Host"] = url_decoded.host,
            ["Content-Type"] = "application/json",
            ["X-ClickHouse-User"] = conf.user,
            ["X-ClickHouse-Key"] = conf.password,
            ["X-ClickHouse-Database"] = conf.database
        }
    })

    if not httpc_res then
        httpc:close()
        return false, "error while sending data to [" .. host .. "] port["
            .. tostring(port) .. "] " .. httpc_err
    end

    -- some error occurred in the server
    if httpc_res.status >= 400 then
        res = false
        err_msg = "server returned status code[" .. httpc_res.status .. "] host["
            .. host .. "] port[" .. tostring(port) .. "] "
            .. "body[" .. httpc_res:read_body() .. "]"
    else
        -- drain the response body so the connection can be reused
        httpc_res:read_body()
    end

    -- previously the connection was neither closed nor pooled, leaking
    -- one socket per batch; return it to the keepalive pool instead
    local keepalive_ok, keepalive_err = httpc:set_keepalive(6000, 5)
    if not keepalive_ok then
        core.log.info("failed to keep the clickhouse connection alive: ",
                      keepalive_err)
    end

    return res, err_msg
end
-- body_filter phase: optionally capture the response body for logging
-- (driven by include_resp_body / include_resp_body_expr).
function _M.body_filter(conf, ctx)
    log_util.collect_body(conf, ctx)
end
-- log phase: queue one entry into the shared batch processor; lazily
-- create the processor (with its flush callback) for the first entry.
function _M.log(conf, ctx)
    local entry = log_util.get_log_entry(plugin_name, conf, ctx)
    if batch_processor_manager:add_entry(conf, entry) then
        return
    end

    -- flush callback: serialise a batch into JSONEachRow-compatible text
    local flush = function(entries, batch_max_size)
        local data, err
        if batch_max_size == 1 then
            data, err = core.json.encode(entries[1]) -- encode as single {}
        else
            local rows = {}
            for i = 1, #entries do
                core.table.insert(rows, core.json.encode(entries[i]))
            end
            data = core.table.concat(rows, " ") -- assemble multi items as string "{} {}"
        end

        if not data then
            return false, 'error occurred while encoding the data: ' .. err
        end

        return send_http_data(conf, data)
    end

    batch_processor_manager:add_entry_to_new_processor(conf, entry, ctx, flush)
end
return _M

View File

@@ -0,0 +1,76 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local require = require
local core = require("apisix.core")
local ok, apisix_ngx_client = pcall(require, "resty.apisix.client")
local tonumber = tonumber
local schema = {
type = "object",
properties = {
max_body_size = {
type = "integer",
minimum = 0,
description = "Maximum message body size in bytes. No restriction when set to 0."
},
},
}
local plugin_name = "client-control"
local _M = {
version = 0.1,
priority = 22000,
name = plugin_name,
schema = schema,
}
-- Validate the plugin configuration against `schema`.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end
-- rewrite phase: enforce the configured client request body limit.
-- Relies on the module-level `ok`/`apisix_ngx_client` pcall result: the
-- feature is only available when built with APISIX-Runtime.
function _M.rewrite(conf, ctx)
    if not ok then
        core.log.error("need to build APISIX-Runtime to support client control")
        return 501
    end
    if conf.max_body_size then
        -- max_body_size == 0 means "no restriction" (see schema)
        local len = tonumber(core.request.header(ctx, "Content-Length"))
        if len then
            -- if length is given in the header, check it immediately
            if conf.max_body_size ~= 0 and len > conf.max_body_size then
                return 413
            end
        end
        -- then check it when reading the body
        local ok, err = apisix_ngx_client.set_client_max_body_size(conf.max_body_size)
        if not ok then
            core.log.error("failed to set client max body size: ", err)
            return 503
        end
    end
end
return _M

View File

@@ -0,0 +1,164 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local ipairs = ipairs
local core = require("apisix.core")
local ngx = ngx
local schema = {
type = "object",
properties = {
type = {
type = "string",
enum = {"consumer_name", "service_id", "route_id", "consumer_group_id"},
default = "consumer_name"
},
blacklist = {
type = "array",
minItems = 1,
items = {type = "string"}
},
whitelist = {
type = "array",
minItems = 1,
items = {type = "string"}
},
allowed_by_methods = {
type = "array",
items = {
type = "object",
properties = {
user = {
type = "string"
},
methods = {
type = "array",
minItems = 1,
items = core.schema.method_schema,
}
}
}
},
rejected_code = {type = "integer", minimum = 200, default = 403},
rejected_msg = {type = "string"}
},
anyOf = {
{required = {"blacklist"}},
{required = {"whitelist"}},
{required = {"allowed_by_methods"}}
},
}
local plugin_name = "consumer-restriction"
local _M = {
version = 0.1,
priority = 2400,
name = plugin_name,
schema = schema,
}
local fetch_val_funcs = {
["route_id"] = function(ctx)
return ctx.route_id
end,
["service_id"] = function(ctx)
return ctx.service_id
end,
["consumer_name"] = function(ctx)
return ctx.consumer_name
end,
["consumer_group_id"] = function (ctx)
return ctx.consumer_group_id
end
}
-- Return true when `value` appears in the array `tab`.
local function is_include(value, tab)
    for _, candidate in ipairs(tab) do
        if candidate == value then
            return true
        end
    end
    return false
end
-- Per-user method ACL. A user with an entry in `allowed_methods` may
-- only use the listed HTTP methods; a user with no entry is allowed
-- everything (the list is an extra restriction, not a default-deny).
-- Only the first matching user entry is consulted.
local function is_method_allowed(allowed_methods, method, user)
    for _, rule in ipairs(allowed_methods) do
        if rule.user == user then
            for _, allowed in ipairs(rule.methods) do
                if allowed == method then
                    return true
                end
            end
            return false
        end
    end
    return true
end
-- Build the rejection response: the configured status code plus either
-- the custom message or a generic one naming the restricted attribute.
local function reject(conf)
    local message = conf.rejected_msg
                    or ("The " .. conf.type .. " is forbidden.")
    return conf.rejected_code, { message = message }
end
-- Validate the plugin configuration; nothing beyond the JSON schema
-- needs checking here.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    if ok then
        return true
    end
    return false, err
end
-- access phase: allow or deny the request based on the configured
-- blacklist, whitelist and per-user method restrictions.
function _M.access(conf, ctx)
    local value = fetch_val_funcs[conf.type](ctx)
    local method = ngx.req.get_method()
    -- the restricted attribute (e.g. consumer_name) must resolve; a
    -- missing value typically means authentication did not run first
    if not value then
        local err_msg = "The request is rejected, please check the "
                        .. conf.type .. " for this request"
        return 401, { message = err_msg}
    end
    core.log.info("value: ", value)
    local block = false
    local whitelisted = false
    -- blacklist has the highest priority: a hit rejects immediately
    if conf.blacklist and #conf.blacklist > 0 then
        if is_include(value, conf.blacklist) then
            return reject(conf)
        end
    end
    -- a whitelist, when present, must contain the value
    if conf.whitelist and #conf.whitelist > 0 then
        whitelisted = is_include(value, conf.whitelist)
        if not whitelisted then
            block = true
        end
    end
    -- method rules are only consulted for values not already whitelisted
    if conf.allowed_by_methods and #conf.allowed_by_methods > 0 and not whitelisted then
        if not is_method_allowed(conf.allowed_by_methods, method, value) then
            block = true
        end
    end
    if block then
        return reject(conf)
    end
end
return _M

View File

@@ -0,0 +1,402 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local plugin = require("apisix.plugin")
local ngx = ngx
local plugin_name = "cors"
local str_find = core.string.find
local re_gmatch = ngx.re.gmatch
local re_compile = require("resty.core.regex").re_match_compile
local re_find = ngx.re.find
local ipairs = ipairs
local origins_pattern = [[^(\*|\*\*|null|\w+://[^,]+(,\w+://[^,]+)*)$]]
local TYPE_ACCESS_CONTROL_ALLOW_ORIGIN = "ACAO"
local TYPE_TIMING_ALLOW_ORIGIN = "TAO"
local lrucache = core.lrucache.new({
type = "plugin",
})
-- metadata schema: a map of named origin lists (key -> comma-separated
-- origins string) that routes can reference via allow_origins_by_metadata
local metadata_schema = {
    type = "object",
    properties = {
        allow_origins = {
            type = "object",
            additionalProperties = {
                type = "string",
                pattern = origins_pattern
            }
        },
    },
}
local schema = {
type = "object",
properties = {
allow_origins = {
description =
"you can use '*' to allow all origins when no credentials," ..
"'**' to allow forcefully(it will bring some security risks, be carefully)," ..
"multiple origin use ',' to split. default: *.",
type = "string",
pattern = origins_pattern,
default = "*"
},
allow_methods = {
description =
"you can use '*' to allow all methods when no credentials," ..
"'**' to allow forcefully(it will bring some security risks, be carefully)," ..
"multiple method use ',' to split. default: *.",
type = "string",
default = "*"
},
allow_headers = {
description =
"you can use '*' to allow all header when no credentials," ..
"'**' to allow forcefully(it will bring some security risks, be carefully)," ..
"multiple header use ',' to split. default: *.",
type = "string",
default = "*"
},
expose_headers = {
description =
"multiple header use ',' to split." ..
"If not specified, no custom headers are exposed.",
type = "string"
},
max_age = {
description =
"maximum number of seconds the results can be cached." ..
"-1 means no cached, the max value is depend on browser," ..
"more details plz check MDN. default: 5.",
type = "integer",
default = 5
},
allow_credential = {
description =
"allow client append credential. according to CORS specification," ..
"if you set this option to 'true', you can not use '*' for other options.",
type = "boolean",
default = false
},
allow_origins_by_regex = {
type = "array",
description =
"you can use regex to allow specific origins when no credentials," ..
"for example use [.*\\.test.com$] to allow a.test.com and b.test.com",
items = {
type = "string",
minLength = 1,
maxLength = 4096,
},
minItems = 1,
uniqueItems = true,
},
allow_origins_by_metadata = {
type = "array",
description =
"set allowed origins by referencing origins in plugin metadata",
items = {
type = "string",
minLength = 1,
maxLength = 4096,
},
minItems = 1,
uniqueItems = true,
},
timing_allow_origins = {
description =
"you can use '*' to allow all origins which can view timing information " ..
"when no credentials," ..
"'**' to allow forcefully (it will bring some security risks, be careful)," ..
"multiple origin use ',' to split. default: nil",
type = "string",
pattern = origins_pattern
},
timing_allow_origins_by_regex = {
type = "array",
description =
"you can use regex to allow specific origins which can view timing information," ..
"for example use [.*\\.test.com] to allow a.test.com and b.test.com",
items = {
type = "string",
minLength = 1,
maxLength = 4096,
},
minItems = 1,
uniqueItems = true,
},
}
}
-- plugin descriptor registered with the APISIX plugin loader
local _M = {
    version = 0.1,
    priority = 4000,
    name = plugin_name,
    schema = schema,
    metadata_schema = metadata_schema,
}
-- Build a set-like lookup table from a comma-separated origins string.
-- Returns nil when the string contains no comma (single origin — no
-- lookup table needed) or when regex iteration fails.
local function create_multiple_origin_cache(allow_origins)
    if not str_find(allow_origins, ",") then
        return nil
    end
    local origin_cache = {}
    local iterator, err = re_gmatch(allow_origins, "([^,]+)", "jiox")
    if not iterator then
        core.log.error("match origins failed: ", err)
        return nil
    end
    while true do
        local origin, err = iterator()
        if err then
            core.log.error("iterate origins failed: ", err)
            return nil
        end
        if not origin then
            break
        end
        -- origin[0] is the whole match for this iteration
        origin_cache[origin[0]] = true
    end
    return origin_cache
end
-- Validate plugin (or metadata) configuration.
-- Beyond the JSON schema this enforces the CORS rule that credentials
-- must not be combined with wildcard values, and pre-compiles the regex
-- rules so invalid patterns are rejected at configuration time.
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_METADATA then
        return core.schema.check(metadata_schema, conf)
    end
    local ok, err = core.schema.check(schema, conf)
    if not ok then
        return false, err
    end
    if conf.allow_credential then
        if conf.allow_origins == "*" or conf.allow_methods == "*" or
            conf.allow_headers == "*" or conf.expose_headers == "*" or
            conf.timing_allow_origins == "*" then
            return false, "you can not set '*' for other option when 'allow_credential' is true"
        end
    end
    if conf.allow_origins_by_regex then
        for i, re_rule in ipairs(conf.allow_origins_by_regex) do
            -- JIT-compile now to surface invalid patterns at config time
            local ok, err = re_compile(re_rule, "j")
            if not ok then
                return false, err
            end
        end
    end
    if conf.timing_allow_origins_by_regex then
        for i, re_rule in ipairs(conf.timing_allow_origins_by_regex) do
            local ok, err = re_compile(re_rule, "j")
            if not ok then
                return false, err
            end
        end
    end
    return true
end
-- Write the Access-Control-* response headers from the resolved origin
-- (ctx.cors_allow_origins) and the plugin configuration.
local function set_cors_headers(conf, ctx)
    local allow_methods = conf.allow_methods
    if allow_methods == "**" then
        -- "**" expands to every HTTP method
        allow_methods = "GET,POST,PUT,DELETE,PATCH,HEAD,OPTIONS,CONNECT,TRACE"
    end
    core.response.set_header("Access-Control-Allow-Origin", ctx.cors_allow_origins)
    core.response.set_header("Access-Control-Allow-Methods", allow_methods)
    core.response.set_header("Access-Control-Max-Age", conf.max_age)
    if conf.expose_headers ~= nil and conf.expose_headers ~= "" then
        core.response.set_header("Access-Control-Expose-Headers", conf.expose_headers)
    end
    if conf.allow_headers == "**" then
        -- "**": reflect whatever headers the preflight request asked for
        core.response.set_header("Access-Control-Allow-Headers",
            core.request.header(ctx, "Access-Control-Request-Headers"))
    else
        core.response.set_header("Access-Control-Allow-Headers", conf.allow_headers)
    end
    if conf.allow_credential then
        core.response.set_header("Access-Control-Allow-Credentials", true)
    end
end
-- Emit the Timing-Allow-Origin header when an allowed origin was resolved
-- for this request; otherwise do nothing.
local function set_timing_headers(conf, ctx)
    local origin = ctx.timing_allow_origin
    if not origin then
        return
    end
    core.response.set_header("Timing-Allow-Origin", origin)
end
-- Resolve a configured allow-origins string against the request origin.
-- "**" reflects the request origin unconditionally (or "*" when absent).
-- A comma-separated list is converted into a cached membership table:
-- the function then returns req_origin when it is a member, or nil when
-- it is not. A single-origin value is returned unchanged for later
-- exact/wildcard matching by the caller. cache_key/cache_version select
-- a shared cache entry (used for metadata-sourced origin lists);
-- otherwise a per-plugin-ctx cache is used.
local function process_with_allow_origins(allow_origin_type, allow_origins, ctx, req_origin,
                                          cache_key, cache_version)
    if allow_origins == "**" then
        allow_origins = req_origin or '*'
    end
    local multiple_origin, err
    if cache_key and cache_version then
        multiple_origin, err = lrucache(
            cache_key, cache_version, create_multiple_origin_cache, allow_origins
        )
    else
        multiple_origin, err = core.lrucache.plugin_ctx(
            lrucache, ctx, allow_origin_type, create_multiple_origin_cache, allow_origins
        )
    end
    if err then
        return 500, {message = "get multiple origin cache failed: " .. err}
    end
    if multiple_origin then
        if multiple_origin[req_origin] then
            allow_origins = req_origin
        else
            return
        end
    end
    return allow_origins
end
-- Match the request origin against the configured regex rules. The rules
-- are concatenated with "|" once and memoized on `conf`, keyed per origin
-- type so the ACAO and TAO rule sets do not clash. Returns req_origin on
-- a match, nil otherwise.
local function process_with_allow_origins_by_regex(allow_origin_type,
                                                   allow_origins_by_regex, conf, ctx, req_origin)
    local allow_origins_by_regex_rules_concat_conf_key =
        "allow_origins_by_regex_rules_concat_" .. allow_origin_type
    if not conf[allow_origins_by_regex_rules_concat_conf_key] then
        local allow_origins_by_regex_rules = {}
        for i, re_rule in ipairs(allow_origins_by_regex) do
            allow_origins_by_regex_rules[i] = re_rule
        end
        conf[allow_origins_by_regex_rules_concat_conf_key] = core.table.concat(
            allow_origins_by_regex_rules, "|")
    end
    -- core.log.warn("regex: ", conf[allow_origins_by_regex_rules_concat_conf_key], "\n ")
    local matched = re_find(req_origin, conf[allow_origins_by_regex_rules_concat_conf_key], "jo")
    if matched then
        return req_origin
    end
end
-- An origin matches when it equals the allowed value exactly, or when
-- every origin is allowed via the "*" wildcard.
local function match_origins(req_origin, allow_origins)
    if allow_origins == '*' then
        return true
    end
    return req_origin == allow_origins
end
-- Resolve allowed origins via plugin metadata: each entry of
-- allow_origins_by_metadata names an origin list stored in the plugin's
-- metadata. Returns req_origin when any referenced list matches, nil
-- otherwise (including when no metadata is configured at all).
local function process_with_allow_origins_by_metadata(allow_origin_type, allow_origins_by_metadata,
                                                      ctx, req_origin)
    if allow_origins_by_metadata == nil then
        return
    end
    local metadata = plugin.plugin_metadata(plugin_name)
    if metadata and metadata.value.allow_origins then
        local allow_origins_map = metadata.value.allow_origins
        for _, key in ipairs(allow_origins_by_metadata) do
            local allow_origins_conf = allow_origins_map[key]
            -- modifiedIndex versions the shared cache so metadata updates
            -- invalidate the cached origin tables
            local allow_origins = process_with_allow_origins(
                allow_origin_type, allow_origins_conf, ctx, req_origin,
                plugin_name .. "#" .. key, metadata.modifiedIndex
            )
            if match_origins(req_origin, allow_origins) then
                return req_origin
            end
        end
    end
end
-- Rewrite phase: capture the Origin header before later phases can alter
-- it, and short-circuit CORS preflight (OPTIONS) requests with 200.
function _M.rewrite(conf, ctx)
    -- remembered for the header_filter phase
    ctx.original_request_origin = core.request.header(ctx, "Origin")
    local method = ctx.var.request_method
    if method == "OPTIONS" then
        return 200
    end
end
-- Decide and emit the CORS (and Timing-Allow-Origin) response headers,
-- based on the origin captured during the rewrite phase. Precedence:
-- metadata-based origins are tried first (when configured), then the
-- local allow_origins string or regex rules, then metadata again as a
-- final fallback.
function _M.header_filter(conf, ctx)
    local req_origin = ctx.original_request_origin
    -- If allow_origins_by_regex is not nil, should be matched to it only
    local allow_origins
    -- whether the locally-configured allow_origins / regex rules should
    -- still be consulted after the metadata check
    local allow_origins_local = false
    if conf.allow_origins_by_metadata then
        allow_origins = process_with_allow_origins_by_metadata(
            TYPE_ACCESS_CONTROL_ALLOW_ORIGIN, conf.allow_origins_by_metadata, ctx, req_origin
        )
        if not match_origins(req_origin, allow_origins) then
            if conf.allow_origins and conf.allow_origins ~= "*" then
                allow_origins_local = true
            end
        end
    else
        allow_origins_local = true
    end
    if conf.allow_origins_by_regex == nil then
        if allow_origins_local then
            allow_origins = process_with_allow_origins(
                TYPE_ACCESS_CONTROL_ALLOW_ORIGIN, conf.allow_origins, ctx, req_origin
            )
        end
    else
        if allow_origins_local then
            allow_origins = process_with_allow_origins_by_regex(
                TYPE_ACCESS_CONTROL_ALLOW_ORIGIN, conf.allow_origins_by_regex,
                conf, ctx, req_origin
            )
        end
    end
    -- the local rules may have replaced a non-matching result; give the
    -- metadata-based origins a final chance to match
    if not match_origins(req_origin, allow_origins) then
        allow_origins = process_with_allow_origins_by_metadata(
            TYPE_ACCESS_CONTROL_ALLOW_ORIGIN, conf.allow_origins_by_metadata, ctx, req_origin
        )
    end
    -- the response varies by Origin whenever not every origin is allowed
    if conf.allow_origins ~= "*" then
        core.response.add_header("Vary", "Origin")
    end
    if allow_origins then
        ctx.cors_allow_origins = allow_origins
        set_cors_headers(conf, ctx)
    end
    local timing_allow_origins
    if conf.timing_allow_origins_by_regex == nil and conf.timing_allow_origins then
        timing_allow_origins = process_with_allow_origins(
            TYPE_TIMING_ALLOW_ORIGIN, conf.timing_allow_origins, ctx, req_origin
        )
    elseif conf.timing_allow_origins_by_regex then
        timing_allow_origins = process_with_allow_origins_by_regex(
            TYPE_TIMING_ALLOW_ORIGIN, conf.timing_allow_origins_by_regex,
            conf, ctx, req_origin
        )
    end
    -- Timing-Allow-Origin is only emitted on an exact/wildcard match
    if timing_allow_origins and match_origins(req_origin, timing_allow_origins) then
        ctx.timing_allow_origin = timing_allow_origins
        set_timing_headers(conf, ctx)
    end
end
return _M

View File

@@ -0,0 +1,168 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local resty_sha256 = require("resty.sha256")
local str = require("resty.string")
local ngx = ngx
local ngx_encode_base64 = ngx.encode_base64
local ngx_decode_base64 = ngx.decode_base64
local ngx_time = ngx.time
local ngx_cookie_time = ngx.cookie_time
local math = math
local SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}
local schema = {
type = "object",
properties = {
key = {
description = "use to generate csrf token",
type = "string",
},
expires = {
description = "expires time(s) for csrf token",
type = "integer",
default = 7200
},
name = {
description = "the csrf token name",
type = "string",
default = "apisix-csrf-token"
}
},
encrypt_fields = {"key"},
required = {"key"}
}
local _M = {
version = 0.1,
priority = 2980,
name = "csrf",
schema = schema,
}
-- Validate the plugin configuration against the JSON schema.
function _M.check_schema(conf)
    return core.schema.check(schema, conf)
end
-- Hex-encoded SHA-256 digest binding the token's random value and expiry
-- timestamp to the secret key. The concatenated string layout must stay
-- stable: previously issued tokens would otherwise fail verification.
local function gen_sign(random, expires, key)
    local sha256 = resty_sha256:new()
    local sign = "{expires:" .. expires .. ",random:" .. random .. ",key:" .. key .. "}"
    sha256:update(sign)
    local digest = sha256:final()
    return str.to_hex(digest)
end
-- Create a new CSRF token: a random value plus issue timestamp, signed
-- with the configured key, JSON-encoded and base64ed for cookie transport.
-- NOTE(review): math.random() is not a cryptographic RNG; token
-- unforgeability rests on the secrecy of conf.key — confirm this is
-- acceptable for the threat model.
local function gen_csrf_token(conf)
    local random = math.random()
    local timestamp = ngx_time()
    local sign = gen_sign(random, timestamp, conf.key)
    local token = {
        random = random,
        expires = timestamp,
        sign = sign,
    }
    local cookie = ngx_encode_base64(core.json.encode(token))
    return cookie
end
-- Verify a CSRF cookie token: base64/JSON decode it, ensure the expected
-- fields are present, reject expired tokens (when conf.expires > 0), and
-- recompute the signature with the shared key. Returns true when valid;
-- every failure path logs the reason and returns false.
local function check_csrf_token(conf, ctx, token)
    local token_str = ngx_decode_base64(token)
    if not token_str then
        core.log.error("csrf token base64 decode error")
        return false
    end
    local token_table, err = core.json.decode(token_str)
    if err then
        core.log.error("decode token error: ", err)
        return false
    end
    local random = token_table["random"]
    if not random then
        core.log.error("no random in token")
        return false
    end
    local expires = token_table["expires"]
    if not expires then
        core.log.error("no expires in token")
        return false
    end
    local time_now = ngx_time()
    -- expires <= 0 disables the expiry check
    if conf.expires > 0 and time_now - expires > conf.expires then
        core.log.error("token has expired")
        return false
    end
    -- re-sign with the same inputs; any tampering changes the digest
    local sign = gen_sign(random, expires, conf.key)
    if token_table["sign"] ~= sign then
        core.log.error("Invalid signatures")
        return false
    end
    return true
end
-- Enforce the double-submit cookie check for state-changing requests:
-- safe methods (GET/HEAD/OPTIONS) pass through; all others must present
-- the token both as a header and as a cookie, with equal values and a
-- valid signature. Any failure yields 401.
function _M.access(conf, ctx)
    local method = core.request.get_method(ctx)
    if core.table.array_find(SAFE_METHODS, method) then
        return
    end
    local header_token = core.request.header(ctx, conf.name)
    if not header_token or header_token == "" then
        return 401, {error_msg = "no csrf token in headers"}
    end
    local cookie_token = ctx.var["cookie_" .. conf.name]
    if not cookie_token then
        return 401, {error_msg = "no csrf cookie"}
    end
    if header_token ~= cookie_token then
        return 401, {error_msg = "csrf token mismatch"}
    end
    local result = check_csrf_token(conf, ctx, cookie_token)
    if not result then
        return 401, {error_msg = "Failed to verify the csrf token signature"}
    end
end
-- Issue a fresh CSRF cookie on every response so the client always holds
-- a token signed within the configured lifetime.
function _M.header_filter(conf, ctx)
    local token = gen_csrf_token(conf)
    local expires_at = ngx_cookie_time(ngx_time() + conf.expires)
    local cookie = conf.name .. "=" .. token
        .. ";path=/;SameSite=Lax;Expires=" .. expires_at
    core.response.add_header("Set-Cookie", cookie)
end
return _M

View File

@@ -0,0 +1,251 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
local core = require("apisix.core")
local plugin = require("apisix.plugin")
local bp_manager_mod = require("apisix.utils.batch-processor-manager")
local fetch_log = require("apisix.utils.log-util").get_full_log
local service_fetch = require("apisix.http.service").get
local ngx = ngx
local udp = ngx.socket.udp
local format = string.format
local concat = table.concat
local tostring = tostring
local plugin_name = "datadog"
local defaults = {
host = "127.0.0.1",
port = 8125,
namespace = "apisix",
constant_tags = {"source:apisix"}
}
local batch_processor_manager = bp_manager_mod.new(plugin_name)
local schema = {
type = "object",
properties = {
prefer_name = {type = "boolean", default = true}
}
}
local metadata_schema = {
type = "object",
properties = {
host = {type = "string", default= defaults.host},
port = {type = "integer", minimum = 0, default = defaults.port},
namespace = {type = "string", default = defaults.namespace},
constant_tags = {
type = "array",
items = {type = "string"},
default = defaults.constant_tags
}
},
}
local _M = {
version = 0.1,
priority = 495,
name = plugin_name,
schema = batch_processor_manager:wrap_schema(schema),
metadata_schema = metadata_schema,
}
-- Validate plugin configuration, or plugin metadata when schema_type
-- selects the metadata schema.
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_METADATA then
        return core.schema.check(metadata_schema, conf)
    end
    return core.schema.check(schema, conf)
end
-- Build the dogstatsd tag suffix ("|#t1,t2,...") for one log entry:
-- the configured constant tags plus route, service, consumer,
-- balancer ip, response status and scheme when present.
-- Returns "" when there are no tags at all.
local function generate_tag(entry, const_tags)
    local tags
    if const_tags and #const_tags > 0 then
        -- clone: the same constant-tags table is shared by every entry
        tags = core.table.clone(const_tags)
    else
        tags = {}
    end
    -- entry.route_id / service_id may hold names rather than ids when
    -- prefer_name resolved them during the log phase
    if entry.route_id and entry.route_id ~= "" then
        core.table.insert(tags, "route_name:" .. entry.route_id)
    end
    if entry.service_id and entry.service_id ~= "" then
        core.table.insert(tags, "service_name:" .. entry.service_id)
    end
    if entry.consumer and entry.consumer.username then
        core.table.insert(tags, "consumer:" .. entry.consumer.username)
    end
    if entry.balancer_ip ~= "" then
        core.table.insert(tags, "balancer_ip:" .. entry.balancer_ip)
    end
    if entry.response.status then
        core.table.insert(tags, "response_status:" .. entry.response.status)
    end
    if entry.scheme ~= "" then
        core.table.insert(tags, "scheme:" .. entry.scheme)
    end
    if #tags > 0 then
        return "|#" .. concat(tags, ',')
    end
    return ""
end
-- Ship one log entry to the dogstatsd server as a series of UDP metrics
-- (request counter, latency histograms, size timers).
-- Returns true on full success, or false plus the error of the last
-- failed send. A failed send does not stop the remaining metrics.
local function send_metric_over_udp(entry, metadata)
    local sock = udp()
    local host, port = metadata.value.host, metadata.value.port
    local ok, err = sock:setpeername(host, port)
    if not ok then
        return false, "failed to connect to UDP server: host[" .. host
                      .. "] port[" .. tostring(port) .. "] err: " .. err
    end
    -- Generate prefix & suffix according dogstatsd udp data format.
    local suffix = generate_tag(entry, metadata.value.constant_tags)
    local prefix = metadata.value.namespace
    if prefix ~= "" then
        prefix = prefix .. "."
    end
    -- Send one dogstatsd datagram. On failure, log it and return the
    -- error message; on success return nil.
    local function send_stat(metric, desc, value, stat_type)
        local sent, serr = sock:send(format("%s:%s|%s%s", prefix .. metric,
                                            value, stat_type, suffix))
        if not sent then
            core.log.error("failed to report " .. desc
                           .. " to dogstatsd server: host[" .. host
                           .. "] port[" .. tostring(port) .. "] err: " .. serr)
            return "error sending " .. metric .. ": " .. serr
        end
        return nil
    end
    -- err_msg keeps the error of the most recent failed send (later
    -- failures overwrite earlier ones, matching the original behaviour)
    local err_msg
    err_msg = send_stat("request.counter", "request count", 1, "c") or err_msg
    err_msg = send_stat("request.latency", "request latency",
                        entry.latency, "h") or err_msg
    if entry.upstream_latency then
        err_msg = send_stat("upstream.latency", "upstream latency",
                            entry.upstream_latency, "h") or err_msg
    end
    err_msg = send_stat("apisix.latency", "apisix latency",
                        entry.apisix_latency, "h") or err_msg
    err_msg = send_stat("ingress.size", "req body size",
                        entry.request.size, "ms") or err_msg
    err_msg = send_stat("egress.size", "response body size",
                        entry.response.size, "ms") or err_msg
    ok, err = sock:close()
    if not ok then
        core.log.error("failed to close the UDP connection, host[",
                       host, "] port[", port, "] ", err)
    end
    if not err_msg then
        return true
    end
    return false, err_msg
end
-- Batch-processor flush callback: send every buffered entry over UDP.
-- Falls back to the compiled-in defaults when no plugin metadata is
-- configured. Returns true, or false plus the error and the index of the
-- first failed entry so the batch processor can retry from that point.
local function push_metrics(entries)
    -- Fetching metadata details
    local metadata = plugin.plugin_metadata(plugin_name)
    core.log.info("metadata: ", core.json.delay_encode(metadata))
    if not metadata then
        core.log.info("received nil metadata: using metadata defaults: ",
                      core.json.delay_encode(defaults, true))
        metadata = {}
        metadata.value = defaults
    end
    core.log.info("sending batch metrics to dogstatsd: ", metadata.value.host,
                  ":", metadata.value.port)
    for i = 1, #entries do
        local ok, err = send_metric_over_udp(entries[i], metadata)
        if not ok then
            return false, err, i
        end
    end
    return true
end
-- Log phase: build the full log entry for this request and queue it on
-- the batch processor; a new processor is created on first use.
function _M.log(conf, ctx)
    local entry = fetch_log(ngx, {})
    entry.balancer_ip = ctx.balancer_ip or ""
    entry.scheme = ctx.upstream_scheme or ""
    -- if prefer_name is set, fetch the service/route name. If the name is nil, fall back to id.
    if conf.prefer_name then
        if entry.service_id and entry.service_id ~= "" then
            local svc = service_fetch(entry.service_id)
            if svc and svc.value.name ~= "" then
                entry.service_id = svc.value.name
            end
        end
        if ctx.route_name and ctx.route_name ~= "" then
            entry.route_id = ctx.route_name
        end
    end
    if batch_processor_manager:add_entry(conf, entry) then
        return
    end
    batch_processor_manager:add_entry_to_new_processor(conf, entry, ctx, push_metrics)
end
return _M

View File

@@ -0,0 +1,160 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local gq_parse = require("graphql").parse
local req_set_body_data = ngx.req.set_body_data
local ipairs = ipairs
local pcall = pcall
local type = type
local schema = {
type = "object",
properties = {
query = {
type = "string",
minLength = 1,
maxLength = 1024,
},
variables = {
type = "array",
items = {
type = "string"
},
minItems = 1,
},
operation_name = {
type = "string",
minLength = 1,
maxLength = 1024
},
},
required = {"query"},
}
local plugin_name = "degraphql"
local _M = {
version = 0.1,
priority = 509,
name = plugin_name,
schema = schema,
}
-- Validate configuration and additionally parse the GraphQL query to
-- catch syntax errors at config time; a document containing multiple
-- operations must name the one to execute.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    if not ok then
        return false, err
    end
    local ok, res = pcall(gq_parse, conf.query)
    if not ok then
        -- res carries the parser's error message when pcall fails
        return false, "failed to parse query: " .. res
    end
    if #res.definitions > 1 and not conf.operation_name then
        return false, "operation_name is required if multiple operations are present in the query"
    end
    return true
end
-- Extract the configured GraphQL variables from a JSON request body.
-- Returns the variables table, or nil plus an HTTP status: 503 when the
-- body cannot be read, 400 when it is missing or not a JSON object.
local function fetch_post_variables(conf)
    local req_body, err = core.request.get_body()
    if err ~= nil then
        core.log.error("failed to get request body: ", err)
        return nil, 503
    end
    if not req_body then
        core.log.error("missing request body")
        return nil, 400
    end
    -- JSON as the default content type
    req_body, err = core.json.decode(req_body)
    if type(req_body) ~= "table" then
        core.log.error("invalid request body can't be decoded: ", err or "bad type")
        return nil, 400
    end
    local variables = {}
    for _, v in ipairs(conf.variables) do
        variables[v] = req_body[v]
    end
    return variables
end
-- Pick the configured GraphQL variables out of the request query string.
local function fetch_get_variables(conf)
    local uri_args = core.request.get_uri_args()
    local vars = {}
    local names = conf.variables
    for i = 1, #names do
        local name = names[i]
        vars[name] = uri_args[name]
    end
    return vars
end
-- Translate a plain GET/POST request into the GraphQL query configured
-- for this route. Variables come from the JSON body (POST) or the query
-- string (GET); other methods are rejected with 405.
function _M.access(conf, ctx)
    local meth = core.request.get_method()
    if meth ~= "POST" and meth ~= "GET" then
        return 405
    end
    local new_body = core.table.new(0, 3)
    if conf.variables then
        local variables, code
        if meth == "POST" then
            variables, code = fetch_post_variables(conf)
        else
            variables, code = fetch_get_variables(conf)
        end
        if not variables then
            return code
        end
        if meth == "POST" then
            new_body["variables"] = variables
        else
            -- uri args carry the variables as a JSON-encoded string
            new_body["variables"] = core.json.encode(variables)
        end
    end
    new_body["operationName"] = conf.operation_name
    new_body["query"] = conf.query
    if meth == "POST" then
        if not conf.variables then
            -- the set_body_data requires to read the body first
            core.request.get_body()
        end
        core.request.set_header(ctx, "Content-Type", "application/json")
        req_set_body_data(core.json.encode(new_body))
    else
        core.request.set_uri_args(ctx, new_body)
    end
end
return _M

View File

@@ -0,0 +1,69 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local ngx_var = ngx.var
local plugin_name = "dubbo-proxy"
-- configuration schema: target dubbo service coordinates; `method`
-- defaults to the request path when omitted (see _M.access)
local schema = {
    type = "object",
    properties = {
        service_name = {
            type = "string",
            minLength = 1,
        },
        service_version = {
            type = "string",
            -- requires a semantic-version style prefix, e.g. "1.0.0"
            pattern = [[^\d+\.\d+\.\d+]],
        },
        method = {
            type = "string",
            minLength = 1,
        },
    },
    required = { "service_name", "service_version"},
}
-- plugin descriptor registered with the APISIX plugin loader
local _M = {
    version = 0.1,
    priority = 507,
    name = plugin_name,
    schema = schema,
}
-- Validate the plugin configuration against the JSON schema.
function _M.check_schema(conf)
    return core.schema.check(schema, conf)
end
-- Mark the request for dubbo proxying and expose the target service
-- coordinates through the nginx variables consumed downstream.
function _M.access(conf, ctx)
    ctx.dubbo_proxy_enabled = true
    ngx_var.dubbo_service_name = conf.service_name
    ngx_var.dubbo_service_version = conf.service_version
    local method = conf.method
    if method then
        ngx_var.dubbo_method = method
    else
        -- fall back to the request path with its leading '/' stripped
        ngx_var.dubbo_method = core.string.sub(ngx_var.uri, 2)
    end
end
return _M

View File

@@ -0,0 +1,121 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local pairs = pairs
local type = type
local ngx = ngx
local schema = {
type = "object",
properties = {
before_body = {
description = "body before the filter phase.",
type = "string"
},
body = {
description = "body to replace upstream response.",
type = "string"
},
after_body = {
description = "body after the modification of filter phase.",
type = "string"
},
headers = {
description = "new headers for response",
type = "object",
minProperties = 1,
},
},
anyOf = {
{required = {"before_body"}},
{required = {"body"}},
{required = {"after_body"}}
},
minProperties = 1,
}
local plugin_name = "echo"
local _M = {
version = 0.1,
priority = 412,
name = plugin_name,
schema = schema,
}
-- Validate the plugin configuration against the JSON schema.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    if ok then
        return true
    end
    return false, err
end
-- Body filter, invoked per response chunk:
--  * conf.body replaces the current chunk and marks the stream as ended
--  * conf.before_body is prepended exactly once (guarded via a ctx flag)
--  * conf.after_body is appended when the final chunk is seen
function _M.body_filter(conf, ctx)
    if conf.body then
        ngx.arg[1] = conf.body
        -- signal eof so no further upstream chunks are emitted
        ngx.arg[2] = true
    end
    if conf.before_body and not ctx.plugin_echo_body_set then
        ngx.arg[1] = conf.before_body .. ngx.arg[1]
        ctx.plugin_echo_body_set = true
    end
    if ngx.arg[2] and conf.after_body then
        ngx.arg[1] = ngx.arg[1] .. conf.after_body
    end
end
-- Header filter: drop content-length related headers when the body will
-- be modified, and apply the configured response headers (flattened once
-- into a field/value array cached on conf for reuse across requests).
function _M.header_filter(conf, ctx)
    if conf.body or conf.before_body or conf.after_body then
        core.response.clear_header_as_body_modified()
    end
    if not conf.headers then
        return
    end
    if not conf.headers_arr then
        conf.headers_arr = {}
        for field, value in pairs(conf.headers) do
            if type(field) == 'string'
                and (type(value) == 'string' or type(value) == 'number') then
                if #field == 0 then
                    -- NOTE(review): return values from header_filter are
                    -- presumably ignored by the plugin runner — confirm
                    -- whether this validation belongs in check_schema
                    return false, 'invalid field length in header'
                end
                core.table.insert(conf.headers_arr, field)
                core.table.insert(conf.headers_arr, value)
            else
                return false, 'invalid type as header value'
            end
        end
    end
    -- headers_arr alternates field, value pairs
    local field_cnt = #conf.headers_arr
    for i = 1, field_cnt, 2 do
        ngx.header[conf.headers_arr[i]] = conf.headers_arr[i+1]
    end
end
return _M

View File

@@ -0,0 +1,281 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local http = require("resty.http")
local log_util = require("apisix.utils.log-util")
local bp_manager_mod = require("apisix.utils.batch-processor-manager")
local ngx = ngx
local str_format = core.string.format
local math_random = math.random
local plugin_name = "elasticsearch-logger"
local batch_processor_manager = bp_manager_mod.new(plugin_name)
local schema = {
type = "object",
properties = {
-- deprecated, use "endpoint_addrs" instead
endpoint_addr = {
type = "string",
pattern = "[^/]$",
},
endpoint_addrs = {
type = "array",
minItems = 1,
items = {
type = "string",
pattern = "[^/]$",
},
},
field = {
type = "object",
properties = {
index = { type = "string"},
},
required = {"index"}
},
log_format = {type = "object"},
auth = {
type = "object",
properties = {
username = {
type = "string",
minLength = 1
},
password = {
type = "string",
minLength = 1
},
},
required = {"username", "password"},
},
timeout = {
type = "integer",
minimum = 1,
default = 10
},
ssl_verify = {
type = "boolean",
default = true
},
include_req_body = {type = "boolean", default = false},
include_req_body_expr = {
type = "array",
minItems = 1,
items = {
type = "array"
}
},
include_resp_body = { type = "boolean", default = false },
include_resp_body_expr = {
type = "array",
minItems = 1,
items = {
type = "array"
}
},
},
encrypt_fields = {"auth.password"},
oneOf = {
{required = {"endpoint_addr", "field"}},
{required = {"endpoint_addrs", "field"}}
},
}
local metadata_schema = {
type = "object",
properties = {
log_format = {
type = "object"
}
},
}
local _M = {
version = 0.1,
priority = 413,
name = plugin_name,
schema = batch_processor_manager:wrap_schema(schema),
metadata_schema = metadata_schema,
}
-- Validate plugin (or metadata) configuration.
-- Also warns when a plain-HTTP endpoint or disabled TLS verification is
-- configured. The deprecated singular "endpoint_addr" was previously skipped
-- by the https check; include it so both layouts get the same warning.
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_METADATA then
        return core.schema.check(metadata_schema, conf)
    end
    local check = {"endpoint_addr", "endpoint_addrs"}
    core.utils.check_https(check, conf, plugin_name)
    core.utils.check_tls_bool({"ssl_verify"}, conf, plugin_name)
    return core.schema.check(schema, conf)
end
-- Query the Elasticsearch root endpoint and extract the major version
-- string (e.g. "7" from "7.17.3").
-- Returns major_version, or nil plus an error message on any failure —
-- previously the request-failure path alone returned "false, err", breaking
-- the otherwise uniform nil-on-error contract.
local function get_es_major_version(uri, conf)
    local httpc = http.new()
    if not httpc then
        return nil, "failed to create http client"
    end

    local headers = {}
    if conf.auth then
        local authorization = "Basic " .. ngx.encode_base64(
            conf.auth.username .. ":" .. conf.auth.password
        )
        headers["Authorization"] = authorization
    end

    -- conf.timeout is in seconds; set_timeout expects milliseconds
    httpc:set_timeout(conf.timeout * 1000)
    local res, err = httpc:request_uri(uri, {
        ssl_verify = conf.ssl_verify,
        method = "GET",
        headers = headers,
    })
    if not res then
        return nil, err
    end

    if res.status ~= 200 then
        return nil, str_format("server returned status: %d, body: %s",
            res.status, res.body or "")
    end

    local json_body, err = core.json.decode(res.body)
    if not json_body then
        return nil, "failed to decode response body: " .. err
    end

    if not json_body.version or not json_body.version.number then
        return nil, "failed to get version from response body"
    end

    -- "7.17.3" -> "7"
    local major_version = json_body.version.number:match("^(%d+)%.")
    if not major_version then
        return nil, "invalid version format: " .. json_body.version.number
    end

    return major_version
end
-- Build one NDJSON bulk item: an action line followed by the log-entry line,
-- each terminated by a newline as the _bulk API requires.
local function get_logger_entry(conf, ctx)
    local log_entry = log_util.get_log_entry(plugin_name, conf, ctx)

    local action = {
        index = {
            _index = conf.field.index
        }
    }
    -- Elasticsearch 5.x/6.x still require an explicit mapping type
    local ver = conf._version
    if ver == "5" or ver == "6" then
        action.index._type = "_doc"
    end

    local action_line = core.json.encode(action)
    local entry_line = core.json.encode(log_entry)
    return action_line .. "\n" .. entry_line .. "\n"
end
-- Detect the Elasticsearch major version and cache it on the plugin conf.
-- Subsequent calls are no-ops once conf._version is populated.
local function fetch_and_update_es_version(conf)
    if conf._version ~= nil then
        return
    end

    local addr = conf.endpoint_addr
    if not addr then
        local addrs = conf.endpoint_addrs
        addr = addrs[math_random(#addrs)]
    end

    local major_version, err = get_es_major_version(addr, conf)
    if err then
        core.log.error("failed to get Elasticsearch version: ", err)
        return
    end

    conf._version = major_version
end
-- Bulk-submit a batch of pre-encoded NDJSON entries to <endpoint>/_bulk.
-- Returns true on success, or false plus an error message.
local function send_to_elasticsearch(conf, entries)
    -- NOTE(review): resty.http's new() is not documented to return an error
    -- as a second value, so err is normally nil here — confirm before relying
    -- on the formatted message
    local httpc, err = http.new()
    if not httpc then
        return false, str_format("create http error: %s", err)
    end

    -- lazily detect the ES major version (no-op once cached on conf)
    fetch_and_update_es_version(conf)

    -- pick one endpoint at random to spread load across the cluster
    local selected_endpoint_addr
    if conf.endpoint_addr then
        selected_endpoint_addr = conf.endpoint_addr
    else
        selected_endpoint_addr = conf.endpoint_addrs[math_random(#conf.endpoint_addrs)]
    end
    local uri = selected_endpoint_addr .. "/_bulk"
    local body = core.table.concat(entries, "")
    local headers = {
        ["Content-Type"] = "application/x-ndjson",
        ["Accept"] = "application/vnd.elasticsearch+json"
    }
    if conf.auth then
        local authorization = "Basic " .. ngx.encode_base64(
            conf.auth.username .. ":" .. conf.auth.password
        )
        headers["Authorization"] = authorization
    end

    core.log.info("uri: ", uri, ", body: ", body)

    -- seconds -> milliseconds
    httpc:set_timeout(conf.timeout * 1000)
    local resp, err = httpc:request_uri(uri, {
        ssl_verify = conf.ssl_verify,
        method = "POST",
        headers = headers,
        body = body
    })
    if not resp then
        return false, err
    end

    if resp.status ~= 200 then
        return false, str_format("elasticsearch server returned status: %d, body: %s",
            resp.status, resp.body or "")
    end

    return true
end
-- Accumulate response body chunks when include_resp_body is enabled.
function _M.body_filter(conf, ctx)
    log_util.collect_body(conf, ctx)
end

function _M.access(conf)
    -- fetch_and_update_es_version will call ES server only the first time
    -- so this should not amount to considerable overhead
    fetch_and_update_es_version(conf)
end
-- Log phase: encode this request's entry and hand it to the batch processor.
function _M.log(conf, ctx)
    local entry = get_logger_entry(conf, ctx)

    -- fast path: a processor already exists for this conf
    if batch_processor_manager:add_entry(conf, entry) then
        return
    end

    -- first entry for this conf: create a processor bound to our sender
    batch_processor_manager:add_entry_to_new_processor(conf, entry, ctx,
        function(entries)
            return send_to_elasticsearch(conf, entries)
        end)
end
return _M

View File

@@ -0,0 +1,510 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local errlog = require("ngx.errlog")
local batch_processor = require("apisix.utils.batch-processor")
local plugin = require("apisix.plugin")
local timers = require("apisix.timers")
local http = require("resty.http")
local producer = require("resty.kafka.producer")
local plugin_name = "error-log-logger"
local table = core.table
local schema_def = core.schema
local ngx = ngx
local tcp = ngx.socket.tcp
local tostring = tostring
local ipairs = ipairs
local string = require("string")
-- cache of the parsed plugin metadata, keyed by metadata.modifiedIndex
local lrucache = core.lrucache.new({
    ttl = 300, count = 32
})
-- cache of kafka producer instances, reused across batches
local kafka_prod_lrucache = core.lrucache.new({
    ttl = 300, count = 32
})

-- Plugin metadata schema: exactly one log sink must be configured (see oneOf).
local metadata_schema = {
    type = "object",
    properties = {
        tcp = {
            type = "object",
            properties = {
                host = schema_def.host_def,
                port = {type = "integer", minimum = 0},
                tls = {type = "boolean", default = false},
                tls_server_name = {type = "string"},
            },
            required = {"host", "port"}
        },
        skywalking = {
            type = "object",
            properties = {
                -- NOTE(review): schema_def.uri sits at array index 1 here
                -- instead of being merged into the object — verify this is
                -- the intended JSON-schema shape
                endpoint_addr = {schema_def.uri, default = "http://127.0.0.1:12900/v3/logs"},
                service_name = {type = "string", default = "APISIX"},
                service_instance_name = {type="string", default = "APISIX Service Instance"},
            },
        },
        clickhouse = {
            type = "object",
            properties = {
                endpoint_addr = {schema_def.uri_def, default="http://127.0.0.1:8123"},
                user = {type = "string", default = "default"},
                password = {type = "string", default = ""},
                database = {type = "string", default = ""},
                logtable = {type = "string", default = ""},
            },
            required = {"endpoint_addr", "user", "password", "database", "logtable"}
        },
        kafka = {
            type = "object",
            properties = {
                brokers = {
                    type = "array",
                    minItems = 1,
                    items = {
                        type = "object",
                        properties = {
                            host = {
                                type = "string",
                                description = "the host of kafka broker",
                            },
                            port = {
                                type = "integer",
                                minimum = 1,
                                maximum = 65535,
                                description = "the port of kafka broker",
                            },
                            sasl_config = {
                                type = "object",
                                description = "sasl config",
                                properties = {
                                    mechanism = {
                                        type = "string",
                                        default = "PLAIN",
                                        enum = {"PLAIN"},
                                    },
                                    user = { type = "string", description = "user" },
                                    password = { type = "string", description = "password" },
                                },
                                required = {"user", "password"},
                            },
                        },
                        required = {"host", "port"},
                    },
                    uniqueItems = true,
                },
                kafka_topic = {type = "string"},
                producer_type = {
                    type = "string",
                    default = "async",
                    enum = {"async", "sync"},
                },
                required_acks = {
                    type = "integer",
                    default = 1,
                    enum = { 0, 1, -1 },
                },
                key = {type = "string"},
                -- in lua-resty-kafka, cluster_name is defined as number
                -- see https://github.com/doujiang24/lua-resty-kafka#new-1
                cluster_name = {type = "integer", minimum = 1, default = 1},
                meta_refresh_interval = {type = "integer", minimum = 1, default = 30},
            },
            required = {"brokers", "kafka_topic"},
        },
        name = {type = "string", default = plugin_name},
        -- minimum severity of error-log lines that get reported
        level = {type = "string", default = "WARN", enum = {"STDERR", "EMERG", "ALERT", "CRIT",
                "ERR", "ERROR", "WARN", "NOTICE", "INFO", "DEBUG"}},
        timeout = {type = "integer", minimum = 1, default = 3},
        keepalive = {type = "integer", minimum = 1, default = 30},
        batch_max_size = {type = "integer", minimum = 0, default = 1000},
        max_retry_count = {type = "integer", minimum = 0, default = 0},
        retry_delay = {type = "integer", minimum = 0, default = 1},
        buffer_duration = {type = "integer", minimum = 1, default = 60},
        inactive_timeout = {type = "integer", minimum = 1, default = 3},
    },
    oneOf = {
        {required = {"skywalking"}},
        {required = {"tcp"}},
        {required = {"clickhouse"}},
        {required = {"kafka"}},
        -- for compatible with old schema
        {required = {"host", "port"}}
    },
    encrypt_fields = {"clickhouse.password"},
}

-- The plugin itself takes no per-route configuration; everything lives in
-- the metadata above.
local schema = {
    type = "object",
}

-- Map schema level names onto nginx log-level constants.
-- Note both "ERR" and "ERROR" resolve to ngx.ERR.
local log_level = {
    STDERR = ngx.STDERR,
    EMERG = ngx.EMERG,
    ALERT = ngx.ALERT,
    CRIT = ngx.CRIT,
    ERR = ngx.ERR,
    ERROR = ngx.ERR,
    WARN = ngx.WARN,
    NOTICE = ngx.NOTICE,
    INFO = ngx.INFO,
    DEBUG = ngx.DEBUG
}

-- current effective metadata config and the lazily-created batch processor
local config = {}
local log_buffer

local _M = {
    version = 0.1,
    priority = 1091,
    name = plugin_name,
    schema = schema,
    metadata_schema = metadata_schema,
    scope = "global",
}
-- Validate plugin or metadata configuration, warning about plain-HTTP sink
-- endpoints and disabled TLS on the TCP sink.
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_METADATA then
        return core.schema.check(metadata_schema, conf)
    end

    local https_fields = {"skywalking.endpoint_addr", "clickhouse.endpoint_addr"}
    core.utils.check_https(https_fields, conf, plugin_name)
    core.utils.check_tls_bool({"tcp.tls"}, conf, plugin_name)

    return core.schema.check(schema, conf)
end
-- Send one batch of log data over a raw (optionally TLS) TCP connection.
-- Returns true on success, or false plus an error message.
local function send_to_tcp_server(data)
    local sock, soc_err = tcp()

    if not sock then
        return false, "failed to init the socket " .. soc_err
    end

    -- seconds -> milliseconds
    sock:settimeout(config.timeout * 1000)

    local tcp_config = config.tcp
    local ok, err = sock:connect(tcp_config.host, tcp_config.port)
    if not ok then
        return false, "failed to connect the TCP server: host[" .. tcp_config.host
            .. "] port[" .. tostring(tcp_config.port) .. "] err: " .. err
    end

    if tcp_config.tls then
        -- args: no session reuse, SNI from tls_server_name, no cert verify
        ok, err = sock:sslhandshake(false, tcp_config.tls_server_name, false)
        if not ok then
            sock:close()
            return false, "failed to perform TLS handshake to TCP server: host["
                .. tcp_config.host .. "] port[" .. tostring(tcp_config.port) .. "] err: " .. err
        end
    end

    local bytes, err = sock:send(data)
    if not bytes then
        sock:close()
        return false, "failed to send data to TCP server: host[" .. tcp_config.host
            .. "] port[" .. tostring(tcp_config.port) .. "] err: " .. err
    end

    -- return the connection to the cosocket keepalive pool
    sock:setkeepalive(config.keepalive * 1000)
    return true
end
-- Ship a batch of error-log lines to the SkyWalking OAP log endpoint.
-- log_message holds [log, "\n", log, "\n", ...], hence the step of 2.
-- Returns ok:boolean plus an error message on failure.
local function send_to_skywalking(log_message)
    local err_msg
    local res = true
    core.log.info("sending a batch logs to ", config.skywalking.endpoint_addr)

    local httpc = http.new()
    httpc:set_timeout(config.timeout * 1000)

    local entries = {}
    local service_instance_name = config.skywalking.service_instance_name
    if service_instance_name == "$hostname" then
        service_instance_name = core.utils.gethostname()
    end
    for i = 1, #log_message, 2 do
        local content = {
            service = config.skywalking.service_name,
            serviceInstance = service_instance_name,
            endpoint = "",
            body = {
                text = {
                    text = log_message[i]
                }
            }
        }
        table.insert(entries, content)
    end

    local httpc_res, httpc_err = httpc:request_uri(
        config.skywalking.endpoint_addr,
        {
            method = "POST",
            body = core.json.encode(entries),
            keepalive_timeout = config.keepalive * 1000,
            headers = {
                ["Content-Type"] = "application/json",
            }
        }
    )

    if not httpc_res then
        return false, "error while sending data to skywalking["
            .. config.skywalking.endpoint_addr .. "] " .. httpc_err
    end

    -- some error occurred in the server
    if httpc_res.status >= 400 then
        res = false
        err_msg = string.format(
            "server returned status code[%s] skywalking[%s] body[%s]",
            httpc_res.status,
            -- was config.skywalking.endpoint_addr.endpoint_addr: indexing a
            -- string, which put a bogus value into the error message
            config.skywalking.endpoint_addr,
            -- request_uri() reads the body eagerly into res.body; its result
            -- has no read_body() method, so calling one failed here
            httpc_res.body
        )
    end
    return res, err_msg
end
-- Ship a batch of error-log lines to ClickHouse via its HTTP interface.
-- log_message holds [log, "\n", log, "\n", ...], hence the step of 2.
-- Returns ok:boolean plus an error message on failure.
local function send_to_clickhouse(log_message)
    local err_msg
    local res = true
    core.log.info("sending a batch logs to ", config.clickhouse.endpoint_addr)

    local httpc = http.new()
    httpc:set_timeout(config.timeout * 1000)

    local entries = {}
    for i = 1, #log_message, 2 do
        -- TODO Here save error log as a whole string to clickhouse 'data' column.
        -- We will add more columns in the future.
        table.insert(entries, core.json.encode({data=log_message[i]}))
    end

    local httpc_res, httpc_err = httpc:request_uri(
        config.clickhouse.endpoint_addr,
        {
            method = "POST",
            body = "INSERT INTO " .. config.clickhouse.logtable .." FORMAT JSONEachRow "
                   .. table.concat(entries, " "),
            keepalive_timeout = config.keepalive * 1000,
            headers = {
                ["Content-Type"] = "application/json",
                ["X-ClickHouse-User"] = config.clickhouse.user,
                ["X-ClickHouse-Key"] = config.clickhouse.password,
                ["X-ClickHouse-Database"] = config.clickhouse.database
            }
        }
    )

    if not httpc_res then
        return false, "error while sending data to clickhouse["
            .. config.clickhouse.endpoint_addr .. "] " .. httpc_err
    end

    -- some error occurred in the server
    if httpc_res.status >= 400 then
        res = false
        err_msg = string.format(
            "server returned status code[%s] clickhouse[%s] body[%s]",
            httpc_res.status,
            -- was config.clickhouse.endpoint_addr.endpoint_addr: indexing a
            -- string, which corrupted the error message
            config.clickhouse.endpoint_addr,
            -- request_uri() reads the body eagerly into res.body; its result
            -- has no read_body() method, so calling one failed here
            httpc_res.body
        )
    end
    return res, err_msg
end
-- Apply the metadata's log level to ngx.errlog's capture filter.
-- Used as the lrucache creator, so it only re-runs when the metadata's
-- modifiedIndex changes. Returns the metadata value (becomes the cached
-- "config"), or nil plus an error message.
local function update_filter(value)
    local level = log_level[value.level]
    local status, err = errlog.set_filter_level(level)
    if not status then
        return nil, "failed to set filter level by ngx.errlog, the error is :" .. err
    else
        core.log.notice("set the filter_level to ", value.level)
    end

    return value
end

-- lrucache creator for the kafka producer: one shared instance per metadata
-- version, reused across batches.
local function create_producer(broker_list, broker_config, cluster_name)
    core.log.info("create new kafka producer instance")
    return producer:new(broker_list, broker_config, cluster_name)
end
-- Ship a batch of error-log lines to Kafka.
-- log_message holds [log, "\n", log, "\n", ...], hence the step of 2.
-- Returns ok:boolean plus an error message on failure.
local function send_to_kafka(log_message)
    -- avoid race of the global config
    local metadata = plugin.plugin_metadata(plugin_name)
    if not (metadata and metadata.value and metadata.modifiedIndex) then
        return false, "please set the correct plugin_metadata for " .. plugin_name
    end
    -- re-read the cached config locally (deliberately shadows the module-level
    -- "config" to avoid racing the timer)
    local config, err = lrucache(plugin_name, metadata.modifiedIndex, update_filter, metadata.value)
    if not config then
        return false, "get config failed: " .. err
    end

    core.log.info("sending a batch logs to kafka brokers: ",
                  core.json.delay_encode(config.kafka.brokers))

    local broker_config = {}
    -- seconds -> milliseconds for the kafka client
    broker_config["request_timeout"] = config.timeout * 1000
    broker_config["producer_type"] = config.kafka.producer_type
    broker_config["required_acks"] = config.kafka.required_acks
    broker_config["refresh_interval"] = config.kafka.meta_refresh_interval * 1000

    -- reuse producer via kafka_prod_lrucache to avoid unbalanced partitions of messages in kafka
    local prod, err = kafka_prod_lrucache(plugin_name, metadata.modifiedIndex,
                                          create_producer, config.kafka.brokers, broker_config,
                                          config.kafka.cluster_name)
    if not prod then
        return false, "get kafka producer failed: " .. err
    end
    core.log.info("kafka cluster name ", config.kafka.cluster_name, ", broker_list[1] port ",
                  prod.client.broker_list[1].port)

    local ok
    for i = 1, #log_message, 2 do
        ok, err = prod:send(config.kafka.kafka_topic,
                            config.kafka.key, core.json.encode(log_message[i]))
        if not ok then
            return false, "failed to send data to Kafka topic: " .. err ..
                          ", brokers: " .. core.json.encode(config.kafka.brokers)
        end
        core.log.info("send data to kafka: ", core.json.delay_encode(log_message[i]))
    end

    return true
end
-- Dispatch a batch to whichever sink is configured; raw TCP is the fallback.
local function send(data)
    if config.skywalking then
        return send_to_skywalking(data)
    end

    if config.clickhouse then
        return send_to_clickhouse(data)
    end

    if config.kafka then
        return send_to_kafka(data)
    end

    return send_to_tcp_server(data)
end
-- Timer callback: refresh the effective metadata config, drain the error
-- logs captured by ngx.errlog and feed them into the batch processor
-- (created lazily on the first non-empty batch).
local function process()
    local metadata = plugin.plugin_metadata(plugin_name)
    if not (metadata and metadata.value and metadata.modifiedIndex) then
        core.log.info("please set the correct plugin_metadata for ", plugin_name)
        return
    else
        local err
        -- update_filter only re-runs when metadata.modifiedIndex changes
        config, err = lrucache(plugin_name, metadata.modifiedIndex, update_filter, metadata.value)
        if not config then
            core.log.warn("set log filter failed for ", err)
            return
        end
        -- translate the legacy flat host/port layout into the "tcp" shape
        if not (config.tcp or config.skywalking or config.clickhouse or config.kafka) then
            config.tcp = {
                host = config.host,
                port = config.port,
                tls = config.tls,
                tls_server_name = config.tls_server_name
            }
            core.log.warn(
                string.format("The schema is out of date. Please update to the new configuration, "
                              .. "for example: {\"tcp\": {\"host\": \"%s\", \"port\": \"%s\"}}",
                              config.host, config.port
            ))
        end
    end

    local err_level = log_level[metadata.value.level]
    local entries = {}
    -- drain in chunks; errlog.get_logs returns a flat array of
    -- (level, time, msg) triples, hence the step of 3 below
    local logs = errlog.get_logs(9)
    while ( logs and #logs>0 ) do
        for i = 1, #logs, 3 do
            -- There will be some stale error logs after the filter level changed.
            -- We should avoid reporting them.
            if logs[i] <= err_level then
                table.insert(entries, logs[i + 2])
                table.insert(entries, "\n")
            end
        end
        logs = errlog.get_logs(9)
    end

    if #entries == 0 then
        return
    end

    if log_buffer then
        for _, v in ipairs(entries) do
            log_buffer:push(v)
        end
        return
    end

    local config_bat = {
        name = config.name,
        retry_delay = config.retry_delay,
        batch_max_size = config.batch_max_size,
        max_retry_count = config.max_retry_count,
        buffer_duration = config.buffer_duration,
        inactive_timeout = config.inactive_timeout,
    }

    local err
    log_buffer, err = batch_processor:new(send, config_bat)

    if not log_buffer then
        core.log.warn("error when creating the batch processor: ", err)
        return
    end

    for _, v in ipairs(entries) do
        log_buffer:push(v)
    end
end
-- Register the periodic timer that drains captured error logs.
function _M.init()
    timers.register_timer("plugin#error-log-logger", process)
end

-- Unregister the timer when the plugin is unloaded.
function _M.destroy()
    timers.unregister_timer("plugin#error-log-logger")
end
return _M

View File

@@ -0,0 +1,152 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local ngx = ngx
local core = require("apisix.core")
local plugin = require("apisix.plugin")
local upstream = require("apisix.upstream")
-- Per-route configuration schema for the example plugin.
local schema = {
    type = "object",
    properties = {
        i = {type = "number", minimum = 0},
        s = {type = "string"},
        t = {type = "array", minItems = 1},
        -- when set, requests are proxied to this ip/port instead of the
        -- route's configured upstream (see _M.access below)
        ip = {type = "string"},
        port = {type = "integer"},
    },
    required = {"i"},
}

-- Plugin metadata schema.
local metadata_schema = {
    type = "object",
    properties = {
        ikey = {type = "number", minimum = 0},
        skey = {type = "string"},
    },
    required = {"ikey", "skey"},
}

local plugin_name = "example-plugin"

local _M = {
    version = 0.1,
    priority = 0,
    name = plugin_name,
    schema = schema,
    metadata_schema = metadata_schema,
}
-- Validate the configuration against the plugin or metadata schema.
function _M.check_schema(conf, schema_type)
    local target = schema
    if schema_type == core.schema.TYPE_METADATA then
        target = metadata_schema
    end
    return core.schema.check(target, conf)
end
function _M.init()
    -- call this function when plugin is loaded
    local attr = plugin.plugin_attr(plugin_name)
    if attr then
        core.log.info(plugin_name, " get plugin attr val: ", attr.val)
    end
end

function _M.destroy()
    -- call this function when plugin is unloaded
end

-- Demonstrates the conf bookkeeping fields available on ctx in rewrite.
function _M.rewrite(conf, ctx)
    core.log.warn("plugin rewrite phase, conf: ", core.json.encode(conf))
    core.log.warn("conf_type: ", ctx.conf_type)
    core.log.warn("conf_id: ", ctx.conf_id)
    core.log.warn("conf_version: ", ctx.conf_version)
end
-- Demonstrates overriding the upstream from a plugin: when conf.ip is set,
-- traffic goes to a one-node roundrobin upstream built on the fly.
function _M.access(conf, ctx)
    core.log.warn("plugin access phase, conf: ", core.json.encode(conf))
    -- return 200, {message = "hit example plugin"}

    if not conf.ip then
        return
    end

    local up_conf = {
        type = "roundrobin",
        nodes = {
            {host = conf.ip, port = conf.port, weight = 1}
        }
    }

    local ok, err = upstream.check_schema(up_conf)
    if not ok then
        return 500, err
    end

    local matched_route = ctx.matched_route
    upstream.set(ctx, up_conf.type .. "#route_" .. matched_route.value.id,
                 ctx.conf_version, up_conf)
    return
end
function _M.header_filter(conf, ctx)
    core.log.warn("plugin header_filter phase, conf: ", core.json.encode(conf))
end

-- ngx.arg[2] is the eof flag of the current response-body chunk.
function _M.body_filter(conf, ctx)
    core.log.warn("plugin body_filter phase, eof: ", ngx.arg[2],
                  ", conf: ", core.json.encode(conf))
end

-- Runs after every plugin's body_filter for the same chunk.
function _M.delayed_body_filter(conf, ctx)
    core.log.warn("plugin delayed_body_filter phase, eof: ", ngx.arg[2],
                  ", conf: ", core.json.encode(conf))
end

function _M.log(conf, ctx)
    core.log.warn("plugin log phase, conf: ", core.json.encode(conf))
end
-- Control API handler for GET /v1/plugin/example-plugin/hello.
-- Replies with JSON when a "json" query arg is present, plain text otherwise.
local function hello()
    local args = ngx.req.get_uri_args()
    if args["json"] then
        return 200, {msg = "world"}
    end
    return 200, "world\n"
end

-- Expose this plugin's control API endpoints.
function _M.control_api()
    local hello_route = {
        methods = {"GET"},
        uris = {"/v1/plugin/example-plugin/hello"},
        handler = hello,
    }
    return {hello_route}
end
return _M

View File

@@ -0,0 +1,40 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local ext = require("apisix.plugins.ext-plugin.init")
local name = "ext-plugin-post-req"

-- Runs external plugins after the builtin request plugins (priority -3000),
-- i.e. just before the request is proxied upstream.
local _M = {
    version = 0.1,
    priority = -3000,
    name = name,
    schema = ext.schema,
}

function _M.check_schema(conf)
    return core.schema.check(_M.schema, conf)
end

-- Delegate the access phase to the external plugin runner.
function _M.access(conf, ctx)
    return ext.communicate(conf, ctx, name)
end

return _M

View File

@@ -0,0 +1,183 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local ext = require("apisix.plugins.ext-plugin.init")
local helper = require("apisix.plugins.ext-plugin.helper")
local constants = require("apisix.constants")
local http = require("resty.http")
local ngx = ngx
local ngx_print = ngx.print
local ngx_flush = ngx.flush
local string = string
local str_sub = string.sub
local name = "ext-plugin-post-resp"

-- Runs external plugins against the upstream response (priority -4000):
-- this plugin fetches the upstream response itself and hands it to the
-- runner before sending it to the client.
local _M = {
    version = 0.1,
    priority = -4000,
    name = name,
    schema = ext.schema,
}

local function include_req_headers(ctx)
    -- TODO: handle proxy_set_header
    return core.request.headers(ctx)
end

-- Close the connection to the upstream (no connection pooling yet).
local function close(http_obj)
    -- TODO: keepalive
    local ok, err = http_obj:close()
    if not ok then
        core.log.error("close http object failed: ", err)
    end
end
-- Re-issue the client's request to the picked upstream server and return the
-- streaming resty.http response object, or nil plus an error message.
local function get_response(ctx, http_obj)
    local ok, err = http_obj:connect({
        scheme = ctx.upstream_scheme,
        host = ctx.picked_server.host,
        port = ctx.picked_server.port,
    })
    if not ok then
        return nil, err
    end
    -- TODO: set timeout
    local uri, args
    if ctx.var.upstream_uri == "" then
        -- use original uri instead of rewritten one
        uri = ctx.var.uri
    else
        uri = ctx.var.upstream_uri

        -- the rewritten one may contain new args
        local index = core.string.find(uri, "?")
        if index then
            local raw_uri = uri
            uri = str_sub(raw_uri, 1, index - 1)
            args = str_sub(raw_uri, index + 1)
        end
    end
    local params = {
        path = uri,
        query = args or ctx.var.args,
        headers = include_req_headers(ctx),
        method = core.request.get_method(),
    }

    local body, err = core.request.get_body()
    if err then
        return nil, err
    end

    if body then
        params["body"] = body
    end

    local res, err = http_obj:request(params)
    if not res then
        return nil, err
    end

    return res, err
end
-- Write one response chunk to the client and flush it.
-- Returns nil on success, or an error string when the write fails.
local function send_chunk(chunk)
    if not chunk then
        return nil
    end

    local ok, print_err = ngx_print(chunk)
    if not ok then
        return "output response failed: ".. (print_err or "")
    end

    local flushed, flush_err = ngx_flush(true)
    if not flushed then
        -- a failed flush is not fatal: the data is already queued for output
        core.log.warn("flush response failed: ", flush_err)
    end

    return nil
end
-- TODO: response body is empty (304 or HEAD)
-- If the upstream returns 304 or the request method is HEAD,
-- there is no response body. In this case,
-- we need to send a response to the client in the plugin,
-- instead of continuing to execute the subsequent plugin.
-- Stream the upstream response to the client. When the runner already
-- buffered body chunks (ctx.runner_ext_response_body), replay those;
-- otherwise read from the live body reader.
-- Returns an error string on failure, nil otherwise.
local function send_response(ctx, res, code)
    -- the runner may have overridden the status via "code"
    ngx.status = code or res.status

    local chunks = ctx.runner_ext_response_body
    if chunks then
        for i=1, #chunks do
            local err = send_chunk(chunks[i])
            if err then
                return err
            end
        end
        return
    end

    return helper.response_reader(res.body_reader, send_chunk)
end
function _M.check_schema(conf)
    return core.schema.check(_M.schema, conf)
end

-- Fetch the upstream response ourselves, let the external runner inspect
-- (and possibly replace) it, then stream it to the client.
-- Returns 502 on upstream failure, or code/body when the runner rewrote
-- the response.
function _M.before_proxy(conf, ctx)
    local http_obj = http.new()
    local res, err = get_response(ctx, http_obj)
    if not res or err then
        core.log.error("failed to request: ", err or "")
        close(http_obj)
        return 502
    end
    ctx.runner_ext_response = res

    core.log.info("response info, status: ", res.status)
    core.log.info("response info, headers: ", core.json.delay_encode(res.headers))

    local code, body = ext.communicate(conf, ctx, name, constants.RPC_HTTP_RESP_CALL)
    if body then
        close(http_obj)
        -- if the body is changed, the code will be set.
        return code, body
    end

    core.log.info("ext-plugin will send response")

    -- send origin response, status maybe changed.
    err = send_response(ctx, res, code)
    close(http_obj)

    if err then
        core.log.error(err)
        -- only emit a 502 if nothing has been written to the client yet
        return not ngx.headers_sent and 502 or nil
    end

    -- fixed typo in the log message (was "succefully")
    core.log.info("ext-plugin send response successfully")
end
return _M

View File

@@ -0,0 +1,40 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local ext = require("apisix.plugins.ext-plugin.init")
local name = "ext-plugin-pre-req"

-- Runs external plugins before the builtin request plugins (priority 12000).
local _M = {
    version = 0.1,
    priority = 12000,
    name = name,
    schema = ext.schema,
}

function _M.check_schema(conf)
    return core.schema.check(_M.schema, conf)
end

-- Delegate the rewrite phase to the external plugin runner.
function _M.rewrite(conf, ctx)
    return ext.communicate(conf, ctx, name)
end

return _M

View File

@@ -0,0 +1,81 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local is_http = ngx.config.subsystem == "http"
local core = require("apisix.core")
local config_local = require("apisix.core.config_local")
local process
if is_http then
process = require "ngx.process"
end
local pl_path = require("pl.path")
local _M = {}

do
    -- memoized socket path, shared by every caller of get_path()
    local path

    -- Return the unix socket address used to talk to the plugin runner.
    -- "ext-plugin.path_for_test" from the local config wins; otherwise a
    -- per-master-process socket under ./conf is used.
    function _M.get_path()
        if not path then
            local local_conf = config_local.local_conf()
            if local_conf then
                local test_path =
                    core.table.try_read_attr(local_conf, "ext-plugin", "path_for_test")
                if test_path then
                    path = "unix:" .. test_path
                end
            end

            if not path then
                local sock = "./conf/apisix-" .. process.get_master_pid() .. ".sock"
                path = "unix:" .. pl_path.abspath(sock)
            end
        end

        return path
    end
end

-- Lifetime (seconds) of the conf token cached by the runner.
function _M.get_conf_token_cache_time()
    return 3600
end
-- Stream chunks from reader(), passing each one to callback(chunk, ...).
-- Stops at EOF; returns an error string if reading or the callback fails,
-- nil on success.
function _M.response_reader(reader, callback, ...)
    if not reader then
        return "get response reader failed"
    end

    while true do
        local chunk, read_err = reader()
        if read_err then
            return "read response failed: ".. (read_err or "")
        end

        if not chunk then
            break
        end

        local cb_err = callback(chunk, ...)
        if cb_err then
            return cb_err
        end
    end
end
return _M

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,175 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local expr = require("resty.expr.v1")
local sleep = core.sleep
local random = math.random
local ipairs = ipairs
local ngx = ngx
local pairs = pairs
local type = type
local plugin_name = "fault-injection"

-- Configuration schema: at least one of "abort" / "delay" must be present
-- (enforced by minProperties = 1 below).
local schema = {
    type = "object",
    properties = {
        abort = {
            type = "object",
            properties = {
                http_status = {type = "integer", minimum = 200},
                body = {type = "string", minLength = 0},
                -- extra response headers; string values may contain nginx
                -- variables that are resolved per request
                headers = {
                    type = "object",
                    minProperties = 1,
                    patternProperties = {
                        ["^[^:]+$"] = {
                            oneOf = {
                                { type = "string" },
                                { type = "number" }
                            }
                        }
                    }
                },
                -- percentage of requests to abort (absent = all of them)
                percentage = {type = "integer", minimum = 0, maximum = 100},
                -- lua-resty-expr rules gating the abort
                vars = {
                    type = "array",
                    maxItems = 20,
                    items = {
                        type = "array",
                    },
                }
            },
            required = {"http_status"},
        },
        delay = {
            type = "object",
            properties = {
                -- seconds to sleep; may be fractional
                duration = {type = "number", minimum = 0},
                percentage = {type = "integer", minimum = 0, maximum = 100},
                vars = {
                    type = "array",
                    maxItems = 20,
                    items = {
                        type = "array",
                    },
                }
            },
            required = {"duration"},
        }
    },
    minProperties = 1,
}

local _M = {
    version = 0.1,
    priority = 11000,
    name = plugin_name,
    schema = schema,
}
-- Decide whether this request falls inside the configured percentage.
-- An absent percentage means the fault always applies.
local function sample_hit(percentage)
    if not percentage then
        return true
    end

    local roll = random(1, 100)
    return roll <= percentage
end
-- Evaluate the configured var expressions against the request context.
-- Returns a truthy result as soon as any expression matches (logical OR
-- across the list), otherwise nil/false.
local function vars_match(vars, ctx)
    local matched
    for _, var in ipairs(vars) do
        -- renamed local: the original shadowed the "expr" module here
        local compiled = expr.new(var)
        matched = compiled:eval(ctx.var)
        if matched then
            break
        end
    end
    return matched
end
-- Pre-compile every expression in a vars list so configuration errors are
-- reported at schema-check time instead of on the request path.
-- Returns true, or false plus the compilation error.
local function check_vars(vars)
    for _, var in ipairs(vars) do
        local _, err = expr.new(var)
        if err then
            core.log.error("failed to create vars expression: ", err)
            return false, err
        end
    end
    return true
end


-- Validate the plugin configuration; the duplicated abort/delay validation
-- loops are now shared via check_vars.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    if not ok then
        return false, err
    end

    if conf.abort and conf.abort.vars then
        local ok, err = check_vars(conf.abort.vars)
        if not ok then
            return false, err
        end
    end

    if conf.delay and conf.delay.vars then
        local ok, err = check_vars(conf.delay.vars)
        if not ok then
            return false, err
        end
    end

    return true
end
-- rewrite phase: optionally sleep (delay) and/or short-circuit the
-- request (abort) when the sampling percentage and var rules match.
function _M.rewrite(conf, ctx)
    core.log.info("plugin rewrite phase, conf: ", core.json.delay_encode(conf))
    -- vars default to "matched" when no rules are configured
    local abort_vars = true
    if conf.abort and conf.abort.vars then
        abort_vars = vars_match(conf.abort.vars, ctx)
    end
    core.log.info("abort_vars: ", abort_vars)
    local delay_vars = true
    if conf.delay and conf.delay.vars then
        delay_vars = vars_match(conf.delay.vars, ctx)
    end
    core.log.info("delay_vars: ", delay_vars)
    -- delay is applied first and is independent from abort
    if conf.delay and sample_hit(conf.delay.percentage) and delay_vars then
        sleep(conf.delay.duration)
    end
    if conf.abort and sample_hit(conf.abort.percentage) and abort_vars then
        if conf.abort.headers then
            for header_name, header_value in pairs(conf.abort.headers) do
                if type(header_value) == "string" then
                    -- string header values may reference nginx vars (e.g. "$remote_addr")
                    header_value = core.utils.resolve_var(header_value, ctx.var)
                end
                ngx.header[header_name] = header_value
            end
        end
        -- returning (status, body) terminates the request here
        return conf.abort.http_status, core.utils.resolve_var(conf.abort.body, ctx.var)
    end
end
return _M

View File

@@ -0,0 +1,184 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- file-logger plugin: appends one JSON log entry per request to a local file.
local log_util = require("apisix.utils.log-util")
local core = require("apisix.core")
local expr = require("resty.expr.v1")
local ngx = ngx
local io_open = io.open
-- resty.apisix.process is only available in APISIX-Runtime; when absent,
-- is_apisix_or is false and the file-handle cache below is disabled
local is_apisix_or, process = pcall(require, "resty.apisix.process")
local plugin_name = "file-logger"
local schema = {
    type = "object",
    properties = {
        -- path of the log file to append to
        path = {
            type = "string"
        },
        log_format = {type = "object"},
        include_req_body = {type = "boolean", default = false},
        include_req_body_expr = {
            type = "array",
            minItems = 1,
            items = {
                type = "array"
            }
        },
        include_resp_body = {type = "boolean", default = false},
        include_resp_body_expr = {
            type = "array",
            minItems = 1,
            items = {
                type = "array"
            }
        },
        -- resty.expr rules; only matching requests are logged
        match = {
            type = "array",
            maxItems = 20,
            items = {
                type = "array",
            },
        }
    },
    required = {"path"}
}
local metadata_schema = {
    type = "object",
    properties = {
        log_format = {
            type = "object"
        }
    }
}
local _M = {
    version = 0.1,
    priority = 399,
    name = plugin_name,
    schema = schema,
    metadata_schema = metadata_schema
}
-- Validate the plugin (or metadata) configuration; the optional `match`
-- expression is compiled up front so syntax errors surface at config time.
function _M.check_schema(conf, schema_type)
    local is_metadata = schema_type == core.schema.TYPE_METADATA
    if is_metadata then
        return core.schema.check(metadata_schema, conf)
    end

    local match_rules = conf.match
    if match_rules then
        local compiled, err = expr.new(match_rules)
        if not compiled then
            return nil, "failed to validate the 'match' expression: " .. err
        end
    end

    return core.schema.check(schema, conf)
end
-- Optional file-handle cache; only defined when running on APISIX-Runtime
-- (needs resty.apisix.process to know when logs were reopened).
local open_file_cache
if is_apisix_or then
    -- TODO: switch to a cache which supports inactive time,
    -- so that unused files would not be cached
    -- maps file path -> { file = <handle>, open_time = <ms timestamp> }
    local path_to_file = core.lrucache.new({
        type = "plugin",
    })
    -- open conf.path in append mode, disable buffering, and stamp the
    -- handle with its open time so a later reopen can be detected
    local function open_file_handler(conf, handler)
        local file, err = io_open(conf.path, 'a+')
        if not file then
            return nil, err
        end
        -- unbuffered: buffered writes can garble entries larger than the buffer
        file:setvbuf("no")
        handler.file = file
        handler.open_time = ngx.now() * 1000
        return handler
    end
    -- return a cached handle for conf.path, reopening it when the process
    -- reopened its log files (per resty.apisix.process) after we opened it
    function open_file_cache(conf)
        local last_reopen_time = process.get_last_reopen_ms()
        local handler, err = path_to_file(conf.path, 0, open_file_handler, conf, {})
        if not handler then
            return nil, err
        end
        if handler.open_time < last_reopen_time then
            core.log.notice("reopen cached log file: ", conf.path)
            handler.file:close()
            local ok, err = open_file_handler(conf, handler)
            if not ok then
                return nil, err
            end
        end
        return handler.file
    end
end
-- Serialize `log_message` as JSON and append it (newline-terminated) to
-- the configured file, via the handle cache when it is available.
local function write_file_data(conf, log_message)
    local msg = core.json.encode(log_message)
    local file, err
    if open_file_cache then
        file, err = open_file_cache(conf)
    else
        file, err = io_open(conf.path, 'a+')
    end
    if not file then
        core.log.error("failed to open file: ", conf.path, ", error info: ", err)
    else
        -- file:write(msg, "\n") will call fwrite several times
        -- which will cause problem with the log output
        -- it should be atomic
        msg = msg .. "\n"
        -- write to file directly, no need flush
        local ok, err = file:write(msg)
        if not ok then
            core.log.error("failed to write file: ", conf.path, ", error info: ", err)
        end
        -- file will be closed by gc, if open_file_cache exists
        if not open_file_cache then
            file:close()
        end
    end
end
-- body_filter phase: accumulate the response body when configured to
-- include it in the log entry.
function _M.body_filter(conf, ctx)
    log_util.collect_body(conf, ctx)
end
-- log phase: build the entry via log-util (honoring log_format and the
-- `match` rules) and append it; a nil entry means nothing to log.
function _M.log(conf, ctx)
    local entry = log_util.get_log_entry(plugin_name, conf, ctx)
    if entry == nil then
        return
    end
    write_file_data(conf, entry)
end
return _M

View File

@@ -0,0 +1,164 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- forward-auth plugin: delegates the authorization decision for each
-- request to an external HTTP service.
local ipairs = ipairs
local core = require("apisix.core")
local http = require("resty.http")
local schema = {
    type = "object",
    properties = {
        uri = {type = "string"},
        -- when true, requests pass through if the auth service is unreachable
        allow_degradation = {type = "boolean", default = false},
        status_on_error = {type = "integer", minimum = 200, maximum = 599, default = 403},
        ssl_verify = {
            type = "boolean",
            default = true,
        },
        request_method = {
            type = "string",
            default = "GET",
            enum = {"GET", "POST"},
            description = "the method for client to request the authorization service"
        },
        request_headers = {
            type = "array",
            default = {},
            items = {type = "string"},
            description = "client request header that will be sent to the authorization service"
        },
        upstream_headers = {
            type = "array",
            default = {},
            items = {type = "string"},
            description = "authorization response header that will be sent to the upstream"
        },
        client_headers = {
            type = "array",
            default = {},
            items = {type = "string"},
            description = "authorization response header that will be sent to"
                          .. "the client when authorizing failed"
        },
        timeout = {
            type = "integer",
            minimum = 1,
            maximum = 60000,
            default = 3000,
            description = "timeout in milliseconds",
        },
        keepalive = {type = "boolean", default = true},
        keepalive_timeout = {type = "integer", minimum = 1000, default = 60000},
        keepalive_pool = {type = "integer", minimum = 1, default = 5},
    },
    required = {"uri"}
}
local _M = {
    version = 0.1,
    priority = 2002,
    name = "forward-auth",
    schema = schema,
}
-- Validate the configuration. check_https / check_tls_bool inspect the
-- `uri` scheme and `ssl_verify` flag for the plugin named _M.name
-- (NOTE(review): these appear to only warn about insecure settings —
-- confirm against apisix.core.utils).
function _M.check_schema(conf)
    local check = {"uri"}
    core.utils.check_https(check, conf, _M.name)
    core.utils.check_tls_bool({"ssl_verify"}, conf, _M.name)
    return core.schema.check(schema, conf)
end
-- access phase: call the external authorization service. 2xx responses
-- let the request continue (optionally copying auth response headers to
-- the upstream request); >= 300 responses are relayed to the client.
function _M.access(conf, ctx)
    local auth_headers = {
        ["X-Forwarded-Proto"] = core.request.get_scheme(ctx),
        ["X-Forwarded-Method"] = core.request.get_method(),
        ["X-Forwarded-Host"] = core.request.get_host(ctx),
        ["X-Forwarded-Uri"] = ctx.var.request_uri,
        ["X-Forwarded-For"] = core.request.get_remote_client_ip(ctx),
    }
    -- for POST, also forward the body-framing headers so the mirrored
    -- request body can be interpreted correctly by the auth service
    if conf.request_method == "POST" then
        auth_headers["Content-Length"] = core.request.header(ctx, "content-length")
        auth_headers["Expect"] = core.request.header(ctx, "expect")
        auth_headers["Transfer-Encoding"] = core.request.header(ctx, "transfer-encoding")
        auth_headers["Content-Encoding"] = core.request.header(ctx, "content-encoding")
    end
    -- append headers that need to be get from the client request header
    if #conf.request_headers > 0 then
        for _, header in ipairs(conf.request_headers) do
            if not auth_headers[header] then
                auth_headers[header] = core.request.header(ctx, header)
            end
        end
    end
    local params = {
        headers = auth_headers,
        keepalive = conf.keepalive,
        ssl_verify = conf.ssl_verify,
        method = conf.request_method
    }
    if params.method == "POST" then
        params.body = core.request.get_body()
    end
    if conf.keepalive then
        params.keepalive_timeout = conf.keepalive_timeout
        params.keepalive_pool = conf.keepalive_pool
    end
    local httpc = http.new()
    httpc:set_timeout(conf.timeout)
    local res, err = httpc:request_uri(conf.uri, params)
    -- transport failure: either degrade open or reject with status_on_error
    if not res and conf.allow_degradation then
        return
    elseif not res then
        core.log.warn("failed to process forward auth, err: ", err)
        return conf.status_on_error
    end
    if res.status >= 300 then
        -- authorization failed: mirror the auth service's status/body and
        -- copy the configured response headers back to the client
        local client_headers = {}
        if #conf.client_headers > 0 then
            for _, header in ipairs(conf.client_headers) do
                client_headers[header] = res.headers[header]
            end
        end
        core.response.set_header(client_headers)
        return res.status, res.body
    end
    -- append headers that need to be get from the auth response header
    for _, header in ipairs(conf.upstream_headers) do
        local header_value = res.headers[header]
        if header_value then
            core.request.set_header(ctx, header, header_value)
        end
    end
end
return _M

View File

@@ -0,0 +1,175 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.
-- gm plugin: enables GM (Chinese national cryptography, via Tongsuo)
-- dual-certificate TLS support.
-- local common libs
local require = require
local pcall = pcall
local ffi = require("ffi")
local C = ffi.C
local get_request = require("resty.core.base").get_request
local core = require("apisix.core")
local radixtree_sni = require("apisix.ssl.router.radixtree_sni")
local apisix_ssl = require("apisix.ssl")
local _, ssl = pcall(require, "resty.apisix.ssl")
local error = error
-- declared only so init() can probe whether the runtime links Tongsuo
ffi.cdef[[
unsigned long Tongsuo_version_num(void)
]]
-- local function
-- Install the GM dual certificate for the current TLS handshake: parse
-- the encryption and signing cert/key PEMs and hand them to
-- resty.apisix.ssl (certs first, then private keys).
local function set_pem_ssl_key(sni, enc_cert, enc_pkey, sign_cert, sign_pkey)
    local r = get_request()
    if r == nil then
        return false, "no request found"
    end
    local parsed_enc_cert, err = apisix_ssl.fetch_cert(sni, enc_cert)
    if not parsed_enc_cert then
        return false, "failed to parse enc PEM cert: " .. err
    end
    local parsed_sign_cert, err = apisix_ssl.fetch_cert(sni, sign_cert)
    if not parsed_sign_cert then
        return false, "failed to parse sign PEM cert: " .. err
    end
    local ok, err = ssl.set_gm_cert(parsed_enc_cert, parsed_sign_cert)
    if not ok then
        return false, "failed to set PEM cert: " .. err
    end
    local parsed_enc_pkey, err = apisix_ssl.fetch_pkey(sni, enc_pkey)
    if not parsed_enc_pkey then
        return false, "failed to parse enc PEM priv key: " .. err
    end
    local parsed_sign_pkey, err = apisix_ssl.fetch_pkey(sni, sign_pkey)
    if not parsed_sign_pkey then
        return false, "failed to parse sign PEM priv key: " .. err
    end
    ok, err = ssl.set_gm_priv_key(parsed_enc_pkey, parsed_sign_pkey)
    if not ok then
        return false, "failed to set PEM priv key: " .. err
    end
    return true
end
-- Replacement for radixtree_sni.set_cert_and_key installed by init():
-- SSL objects marked with `gm` install the dual certificate, everything
-- else falls through to the original implementation.
local original_set_cert_and_key
local function set_cert_and_key(sni, value)
    if value.gm then
        -- process as GM certificate
        -- For GM dual certificate, the `cert` and `key` will be encryption cert/key.
        -- The first item in `certs` and `keys` will be sign cert/key.
        local enc_cert = value.cert
        local enc_pkey = value.key
        local sign_cert = value.certs[1]
        local sign_pkey = value.keys[1]
        return set_pem_ssl_key(sni, enc_cert, enc_pkey, sign_cert, sign_pkey)
    end
    return original_set_cert_and_key(sni, value)
end
-- Replacement for apisix_ssl.check_ssl_conf installed by init(): for
-- GM-marked objects, additionally require exactly one sign cert/key pair
-- beyond the enc pair validated by the original checker.
local original_check_ssl_conf
local function check_ssl_conf(in_dp, conf)
    if conf.gm then
        -- process as GM certificate
        -- For GM dual certificate, the `cert` and `key` will be encryption cert/key.
        -- The first item in `certs` and `keys` will be sign cert/key.
        local ok, err = original_check_ssl_conf(in_dp, conf)
        -- check cert/key first in the original method
        if not ok then
            return nil, err
        end
        -- Currently, APISIX doesn't check the cert type (ECDSA / RSA). So we skip the
        -- check for now in this plugin.
        local num_certs = conf.certs and #conf.certs or 0
        local num_keys = conf.keys and #conf.keys or 0
        if num_certs ~= 1 or num_keys ~= 1 then
            return nil, "sign cert/key are required"
        end
        return true
    end
    return original_check_ssl_conf(in_dp, conf)
end
-- module define
local plugin_name = "gm"
-- plugin schema (the plugin currently takes no options)
local plugin_schema = {
    type = "object",
    properties = {
    },
}
local _M = {
    version = 0.1,  -- plugin version
    priority = -43,
    name = plugin_name,  -- plugin name
    schema = plugin_schema,  -- plugin schema
}
-- Enable GM support: verify the runtime is built with Tongsuo, enable
-- NTLS, hook the SNI cert loader and SSL conf checker, and extend the
-- SSL schema with the `gm` marker field.
function _M.init()
    if not pcall(function () return C.Tongsuo_version_num end) then
        error("need to build Tongsuo (https://github.com/Tongsuo-Project/Tongsuo) " ..
              "into the APISIX-Runtime")
    end
    ssl.enable_ntls()
    -- save the originals so destroy() can restore them
    original_set_cert_and_key = radixtree_sni.set_cert_and_key
    radixtree_sni.set_cert_and_key = set_cert_and_key
    original_check_ssl_conf = apisix_ssl.check_ssl_conf
    apisix_ssl.check_ssl_conf = check_ssl_conf
    if core.schema.ssl.properties.gm ~= nil then
        error("Field 'gm' is occupied")
    end
    -- inject a mark to distinguish GM certificate
    core.schema.ssl.properties.gm = {
        type = "boolean"
    }
end
-- Undo everything init() installed: disable NTLS, restore the original
-- hooks, and remove the `gm` schema field.
function _M.destroy()
    ssl.disable_ntls()
    radixtree_sni.set_cert_and_key = original_set_cert_and_key
    apisix_ssl.check_ssl_conf = original_check_ssl_conf
    core.schema.ssl.properties.gm = nil
end
-- module interface for schema check
-- @param `conf` user defined conf data
-- @param `schema_type` defined in `apisix/core/schema.lua` (unused here)
-- @return <boolean>
function _M.check_schema(conf, schema_type)
    return core.schema.check(plugin_schema, conf)
end
return _M

View File

@@ -0,0 +1,265 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- google-cloud-logging plugin: ships request logs to Google Cloud
-- Logging in batches via the shared batch-processor.
local core = require("apisix.core")
local tostring = tostring
local http = require("resty.http")
local log_util = require("apisix.utils.log-util")
local bp_manager_mod = require("apisix.utils.batch-processor-manager")
local google_oauth = require("apisix.utils.google-cloud-oauth")
-- per-conf cache for the oauth client object
local lrucache = core.lrucache.new({
    type = "plugin",
})
local plugin_name = "google-cloud-logging"
local batch_processor_manager = bp_manager_mod.new(plugin_name)
local schema = {
    type = "object",
    properties = {
        auth_config = {
            type = "object",
            properties = {
                client_email = { type = "string" },
                private_key = { type = "string" },
                project_id = { type = "string" },
                token_uri = {
                    type = "string",
                    default = "https://oauth2.googleapis.com/token"
                },
                -- https://developers.google.com/identity/protocols/oauth2/scopes#logging
                scope = {
                    type = "array",
                    items = {
                        description = "Google OAuth2 Authorization Scopes",
                        type = "string",
                    },
                    minItems = 1,
                    uniqueItems = true,
                    default = {
                        "https://www.googleapis.com/auth/logging.read",
                        "https://www.googleapis.com/auth/logging.write",
                        "https://www.googleapis.com/auth/logging.admin",
                        "https://www.googleapis.com/auth/cloud-platform"
                    }
                },
                -- `scopes` overrides `scope` when both are set
                -- (see create_oauth_object)
                scopes = {
                    type = "array",
                    items = {
                        description = "Google OAuth2 Authorization Scopes",
                        type = "string",
                    },
                    minItems = 1,
                    uniqueItems = true
                },
                entries_uri = {
                    type = "string",
                    default = "https://logging.googleapis.com/v2/entries:write"
                },
            },
            required = { "client_email", "private_key", "project_id", "token_uri" }
        },
        ssl_verify = {
            type = "boolean",
            default = true
        },
        -- path of a JSON file with the same fields as auth_config
        auth_file = { type = "string" },
        -- https://cloud.google.com/logging/docs/reference/v2/rest/v2/MonitoredResource
        resource = {
            type = "object",
            properties = {
                type = { type = "string" },
                labels = { type = "object" }
            },
            default = {
                type = "global"
            },
            required = { "type" }
        },
        -- https://cloud.google.com/logging/docs/reference/v2/rest/v2/LogEntry
        log_id = {
            type = "string",
            default = "apisix.apache.org%2Flogs"
        },
        log_format = {type = "object"},
    },
    -- exactly one credential source must be configured
    oneOf = {
        { required = { "auth_config" } },
        { required = { "auth_file" } },
    },
    encrypt_fields = {"auth_config.private_key"},
}
local metadata_schema = {
    type = "object",
    properties = {
        log_format = {
            type = "object"
        }
    },
}
-- POST the batched log entries to the Cloud Logging entries:write
-- endpoint, authenticating with a token from the oauth client.
-- Returns the response body, or nil plus an error string.
local function send_to_google(oauth, entries)
    local http_new = http.new()
    local access_token = oauth:generate_access_token()
    if not access_token then
        return nil, "failed to get google oauth token"
    end
    local res, err = http_new:request_uri(oauth.entries_uri, {
        ssl_verify = oauth.ssl_verify,
        method = "POST",
        body = core.json.encode({
            entries = entries,
            -- partialSuccess = false: reject the whole batch on any bad entry
            partialSuccess = false,
        }),
        headers = {
            ["Content-Type"] = "application/json",
            ["Authorization"] = (oauth.access_token_type or "Bearer") .. " " .. access_token,
        },
    })
    if not res then
        return nil, "failed to write log to google, " .. err
    end
    if res.status ~= 200 then
        -- non-200: surface Google's response body as the error
        return nil, res.body
    end
    return res.body
end
-- Resolve the OAuth credential table: inline `auth_config` wins;
-- otherwise the JSON file referenced by `auth_file` is read and decoded.
local function fetch_oauth_conf(conf)
    local inline_conf = conf.auth_config
    if inline_conf then
        return inline_conf
    end

    local auth_file = conf.auth_file
    if not auth_file then
        return nil, "configuration is not defined"
    end

    local content, err = core.io.get_file(auth_file)
    if not content then
        return nil, "failed to read configuration, file: " .. auth_file .. " err: " .. err
    end

    local decoded
    decoded, err = core.json.decode(content)
    if not decoded then
        return nil, "config parse failure, data: " .. content .. " , err: " .. err
    end

    return decoded
end
-- Build the google-cloud-oauth client from the resolved credentials;
-- the newer `scopes` field takes precedence over the legacy `scope`.
local function create_oauth_object(conf)
    local auth_conf, err = fetch_oauth_conf(conf)
    if not auth_conf then
        return nil, err
    end
    auth_conf.scope = auth_conf.scopes or auth_conf.scope
    return google_oauth.new(auth_conf, conf.ssl_verify)
end
-- Convert the generic log-util entry into the Cloud Logging LogEntry
-- shape. With a customized log_format the whole entry becomes the
-- jsonPayload; otherwise request/response details map to httpRequest.
local function get_logger_entry(conf, ctx, oauth)
    local entry, customized = log_util.get_log_entry(plugin_name, conf, ctx)
    local google_entry
    if not customized then
        google_entry = {
            httpRequest = {
                requestMethod = entry.request.method,
                requestUrl = entry.request.url,
                requestSize = entry.request.size,
                status = entry.response.status,
                responseSize = entry.response.size,
                userAgent = entry.request.headers and entry.request.headers["user-agent"],
                remoteIp = entry.client_ip,
                serverIp = entry.upstream,
                -- entry.latency is in ms; Cloud Logging expects seconds, e.g. "0.123s"
                latency = tostring(core.string.format("%0.3f", entry.latency / 1000)) .. "s"
            },
            jsonPayload = {
                route_id = entry.route_id,
                service_id = entry.service_id,
            },
        }
    else
        google_entry = {
            jsonPayload = entry,
        }
    end
    google_entry.labels = {
        source = "apache-apisix-google-cloud-logging"
    }
    google_entry.timestamp = log_util.get_rfc3339_zulu_timestamp()
    google_entry.resource = conf.resource
    google_entry.insertId = ctx.var.request_id
    google_entry.logName = core.string.format("projects/%s/logs/%s", oauth.project_id, conf.log_id)
    return google_entry
end
local _M = {
    version = 0.1,
    priority = 407,
    name = plugin_name,
    metadata_schema = metadata_schema,
    -- wrap_schema merges in the shared batch-processor options
    schema = batch_processor_manager:wrap_schema(schema),
}
-- Validate plugin config or plugin metadata depending on schema_type.
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_METADATA then
        return core.schema.check(metadata_schema, conf)
    end
    return core.schema.check(schema, conf)
end
-- log phase: build the LogEntry and hand it to the batch processor;
-- when no processor exists yet for this conf, create one whose flush
-- callback ships entries via send_to_google.
function _M.log(conf, ctx)
    local oauth, err = core.lrucache.plugin_ctx(lrucache, ctx, nil,
                                                create_oauth_object, conf)
    if not oauth then
        core.log.error("failed to fetch google-cloud-logging.oauth object: ", err)
        return
    end
    local entry = get_logger_entry(conf, ctx, oauth)
    if batch_processor_manager:add_entry(conf, entry) then
        return
    end
    local process = function(entries)
        return send_to_google(oauth, entries)
    end
    batch_processor_manager:add_entry_to_new_processor(conf, entry, ctx, process)
end
return _M

View File

@@ -0,0 +1,211 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- grpc-transcode plugin: translates JSON <-> gRPC so HTTP clients can
-- call gRPC upstreams, driven by a compiled .proto definition.
local ngx = ngx
local core = require("apisix.core")
local schema_def = require("apisix.schema_def")
local proto = require("apisix.plugins.grpc-transcode.proto")
local request = require("apisix.plugins.grpc-transcode.request")
local response = require("apisix.plugins.grpc-transcode.response")
local plugin_name = "grpc-transcode"
-- allowed values for the lua-protobuf serialization options (pb_option)
local pb_option_def = {
    { description = "enum as result",
      type = "string",
      enum = {"enum_as_name", "enum_as_value"},
    },
    { description = "int64 as result",
      type = "string",
      enum = {"int64_as_number", "int64_as_string", "int64_as_hexstring"},
    },
    { description ="default values option",
      type = "string",
      enum = {"auto_default_values", "no_default_values",
              "use_default_values", "use_default_metatable"},
    },
    { description = "hooks option",
      type = "string",
      enum = {"enable_hooks", "disable_hooks" },
    },
}
local schema = {
    type = "object",
    properties = {
        -- id of the proto resource holding the .proto definition
        proto_id = schema_def.id_schema,
        service = {
            description = "the grpc service name",
            type = "string"
        },
        method = {
            description = "the method name in the grpc service.",
            type = "string"
        },
        deadline = {
            description = "deadline for grpc, millisecond",
            type = "number",
            default = 0
        },
        pb_option = {
            type = "array",
            items = { type="string", anyOf = pb_option_def },
            minItems = 1,
            default = {
                "enum_as_name",
                "int64_as_number",
                "auto_default_values",
                "disable_hooks",
            }
        },
        show_status_in_body = {
            description = "show decoded grpc-status-details-bin in response body",
            type = "boolean",
            default = false
        },
        -- https://github.com/googleapis/googleapis/blob/b7cb84f5d42e6dba0fdcc2d8689313f6a8c9d7b9/
        -- google/rpc/status.proto#L46
        status_detail_type = {
            description = "the message type of the grpc-status-details-bin's details part, "
                          .. "if not given, the details part will not be decoded",
            type = "string",
        },
    },
    additionalProperties = true,
    required = { "proto_id", "service", "method" },
}
-- gRPC status code -> HTTP status code mapping.
-- Based on https://cloud.google.com/apis/design/errors#handling_errors
local status_rel = {
    ["1"] = 499,    -- CANCELLED
    ["2"] = 500,    -- UNKNOWN
    ["3"] = 400,    -- INVALID_ARGUMENT
    ["4"] = 504,    -- DEADLINE_EXCEEDED
    ["5"] = 404,    -- NOT_FOUND
    ["6"] = 409,    -- ALREADY_EXISTS
    ["7"] = 403,    -- PERMISSION_DENIED
    ["8"] = 429,    -- RESOURCE_EXHAUSTED
    ["9"] = 400,    -- FAILED_PRECONDITION
    ["10"] = 409,   -- ABORTED
    ["11"] = 400,   -- OUT_OF_RANGE
    ["12"] = 501,   -- UNIMPLEMENTED
    ["13"] = 500,   -- INTERNAL
    ["14"] = 503,   -- UNAVAILABLE
    ["15"] = 500,   -- DATA_LOSS
    ["16"] = 401,   -- UNAUTHENTICATED
}
local _M = {
    version = 0.1,
    priority = 506,
    name = plugin_name,
    schema = schema,
}
-- Plugin lifecycle: delegate to the proto submodule, which manages the
-- etcd-backed proto definitions.
function _M.init()
    proto.init()
end
function _M.destroy()
    proto.destroy()
end
-- Validate the user configuration against the plugin schema.
function _M.check_schema(conf)
    local valid, schema_err = core.schema.check(schema, conf)
    if valid then
        return true
    end
    return false, schema_err
end
-- access phase: fetch/compile the proto and transform the JSON request
-- into a gRPC message; on failure the request is rejected with err_code.
function _M.access(conf, ctx)
    core.log.info("conf: ", core.json.delay_encode(conf))
    local proto_id = conf.proto_id
    if not proto_id then
        -- schema marks proto_id as required, so this is defensive only
        core.log.error("proto id miss: ", proto_id)
        return
    end
    local proto_obj, err = proto.fetch(proto_id)
    if err then
        core.log.error("proto load error: ", err)
        return
    end
    local ok, err, err_code = request(proto_obj, conf.service,
                                      conf.method, conf.pb_option, conf.deadline)
    if not ok then
        core.log.error("transform request error: ", err)
        return err_code
    end
    -- stash the compiled proto for the response transformation phases
    ctx.proto_obj = proto_obj
end
-- header_filter phase: mark the response as JSON and map the gRPC
-- status response header onto an HTTP status code.
function _M.header_filter(conf, ctx)
    if ngx.status >= 300 then
        return
    end
    ngx.header["Content-Type"] = "application/json"
    -- the body length changes during transcoding, so Content-Length is dropped
    ngx.header.content_length = nil
    local headers = ngx.resp.get_headers()
    if headers["grpc-status"] ~= nil and headers["grpc-status"] ~= "0" then
        local http_status = status_rel[headers["grpc-status"]]
        if http_status ~= nil then
            ngx.status = http_status
        else
            -- grpc-status outside the known mapping: non-standard catch-all
            ngx.status = 599
        end
    else
        -- The error response body does not contain grpc-status and grpc-message
        ngx.header["Trailer"] = {"grpc-status", "grpc-message"}
    end
end
-- body_filter phase: decode the gRPC response body into JSON, optionally
-- embedding the decoded grpc-status details on error responses.
function _M.body_filter(conf, ctx)
    if ngx.status >= 300 and not conf.show_status_in_body then
        return
    end
    -- proto_obj is only set when the access phase succeeded
    local proto_obj = ctx.proto_obj
    if not proto_obj then
        return
    end
    local err = response(ctx, proto_obj, conf.service, conf.method, conf.pb_option,
                         conf.show_status_in_body, conf.status_detail_type)
    if err then
        core.log.error("transform response error: ", err)
        return
    end
end
return _M

View File

@@ -0,0 +1,279 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- grpc-transcode proto submodule: loads, compiles and caches .proto
-- definitions stored in the /protos etcd collection.
local core = require("apisix.core")
local config_util = require("apisix.core.config_util")
local pb = require("pb")
local protoc = require("protoc")
local pcall = pcall
local ipairs = ipairs
local decode_base64 = ngx.decode_base64
-- etcd-backed collection of proto resources, created in _M.init()
local protos
-- cache of compiled protos, keyed by proto id + config version
local lrucache_proto = core.lrucache.new({
    ttl = 300, count = 100
})
local proto_fake_file = "filename for loaded"
-- Compile a textual .proto definition with protoc and build an index of
-- "package.ServiceName" -> { method_name -> method descriptor }.
-- Returns the compiled table, or nil plus an error.
local function compile_proto_text(content)
    protoc.reload()
    local _p  = protoc.new()
    -- the loaded proto won't appears in _p.loaded without a file name after lua-protobuf=0.3.2,
    -- which means _p.loaded after _p:load(content) is always empty, so we can pass a fake file
    -- name to keep the code below unchanged, or we can create our own load function with returning
    -- the loaded DescriptorProto table additionally, see more details in
    -- https://github.com/apache/apisix/pull/4368
    local ok, res = pcall(_p.load, _p, content, proto_fake_file)
    if not ok then
        return nil, res
    end
    if not res or not _p.loaded then
        return nil, "failed to load proto content"
    end
    local compiled = _p.loaded
    local index = {}
    for _, s in ipairs(compiled[proto_fake_file].service or {}) do
        local method_index = {}
        for _, m in ipairs(s.method) do
            method_index[m.name] = m
        end
        index[compiled[proto_fake_file].package .. '.' .. s.name] = method_index
    end
    compiled[proto_fake_file].index = index
    return compiled
end
-- Compile a base64-encoded binary FileDescriptorSet and build the same
-- "package.ServiceName" -> { method_name -> descriptor } index as
-- compile_proto_text.
-- Fix: the failure paths previously returned a bare nil with no error
-- message, unlike compile_proto_text; now each returns a description
-- (callers that only check the first return value are unaffected).
local function compile_proto_bin(content)
    content = decode_base64(content)
    if not content then
        return nil, "invalid base64 proto content"
    end

    -- pb.load doesn't return err, only a boolean
    local ok = pb.load(content)
    if not ok then
        return nil, "failed to load the binary proto descriptor"
    end

    local files = pb.decode("google.protobuf.FileDescriptorSet", content).file
    local index = {}
    for _, f in ipairs(files) do
        for _, s in ipairs(f.service or {}) do
            local method_index = {}
            for _, m in ipairs(s.method) do
                method_index[m.name] = m
            end
            index[f.package .. '.' .. s.name] = method_index
        end
    end

    local compiled = {}
    compiled[proto_fake_file] = {}
    compiled[proto_fake_file].index = index
    return compiled
end
-- Compile proto `content`: try it as text first, then as a base64
-- binary descriptor. Each compiled proto keeps its own pb state so
-- definitions from different protos don't clash.
local function compile_proto(content)
    -- clear pb state
    local old_pb_state = pb.state(nil)
    local compiled, err = compile_proto_text(content)
    if not compiled then
        compiled = compile_proto_bin(content)
        if not compiled then
            -- report the (more descriptive) text-compilation error
            return nil, err
        end
    end
    -- fetch pb state
    compiled.pb_state = pb.state(old_pb_state)
    return compiled
end
local _M = {
    version = 0.1,
    compile_proto = compile_proto,
    proto_fake_file = proto_fake_file
}
-- Look up the proto resource with id `proto_id` in the etcd-backed
-- collection and compile its content.
local function create_proto_obj(proto_id)
    if protos.values == nil then
        return nil
    end
    local content
    for _, proto in config_util.iterate_values(protos.values) do
        if proto_id == proto.value.id then
            content = proto.value.content
            break
        end
    end
    if not content then
        return nil, "failed to find proto by id: " .. proto_id
    end
    return compile_proto(content)
end
-- Fetch the compiled proto for `proto_id`; cached per config version,
-- so the cache is refreshed when the protos collection changes.
function _M.fetch(proto_id)
    return lrucache_proto(proto_id, protos.conf_version,
                          create_proto_obj, proto_id)
end
-- Expose the raw protos values and their config version (nil, nil
-- before init() has run).
function _M.protos()
    if not protos then
        return nil, nil
    end
    return protos.values, protos.conf_version
end
-- Inline proto3 definitions of google.protobuf.Any and google.rpc.Status
-- (named ErrorStatus here), used to decode the grpc-status-details-bin
-- trailer. Kept inline so it works after `luarocks install apisix`.
local grpc_status_proto = [[
    syntax = "proto3";
    package grpc_status;
    message Any {
        // A URL/resource name that uniquely identifies the type of the serialized
        // protocol buffer message. This string must contain at least
        // one "/" character. The last segment of the URL's path must represent
        // the fully qualified name of the type (as in
        // `path/google.protobuf.Duration`). The name should be in a canonical form
        // (e.g., leading "." is not accepted).
        //
        // In practice, teams usually precompile into the binary all types that they
        // expect it to use in the context of Any. However, for URLs which use the
        // scheme `http`, `https`, or no scheme, one can optionally set up a type
        // server that maps type URLs to message definitions as follows:
        //
        // * If no scheme is provided, `https` is assumed.
        // * An HTTP GET on the URL must yield a [google.protobuf.Type][]
        //   value in binary format, or produce an error.
        // * Applications are allowed to cache lookup results based on the
        //   URL, or have them precompiled into a binary to avoid any
        //   lookup. Therefore, binary compatibility needs to be preserved
        //   on changes to types. (Use versioned type names to manage
        //   breaking changes.)
        //
        // Note: this functionality is not currently available in the official
        // protobuf release, and it is not used for type URLs beginning with
        // type.googleapis.com.
        //
        // Schemes other than `http`, `https` (or the empty scheme) might be
        // used with implementation specific semantics.
        //
        string type_url = 1;
        // Must be a valid serialized protocol buffer of the above specified type.
        bytes value = 2;
    }
    // The `Status` type defines a logical error model that is suitable for
    // different programming environments, including REST APIs and RPC APIs. It is
    // used by [gRPC](https://github.com/grpc). Each `Status` message contains
    // three pieces of data: error code, error message, and error details.
    //
    // You can find out more about this error model and how to work with it in the
    // [API Design Guide](https://cloud.google.com/apis/design/errors).
    message ErrorStatus {
        // The status code, which should be an enum value of [google.rpc.Code][google.rpc.Code].
        int32 code = 1;
        // A developer-facing error message, which should be in English. Any
        // user-facing error message should be localized and sent in the
        // [google.rpc.Status.details][google.rpc.Status.details] field, or localized by the client.
        string message = 2;
        // A list of messages that carry the error details.  There is a common set of
        // message types for APIs to use.
        repeated Any details = 3;
    }
]]
local status_pb_state
--- Compile the embedded gRPC status proto into a dedicated pb state.
-- Saves and restores the caller's pb state so the compilation does not
-- clobber descriptors loaded for user protos.
-- @return an error message string on failure; nil on success
local function init_status_pb_state()
    if not status_pb_state then
        -- clear current pb state
        local old_pb_state = pb.state(nil)
        -- initialize protoc compiler
        protoc.reload()
        local status_protoc = protoc.new()
        -- do not use loadfile here, it can not load the proto file when using a relative address
        -- after luarocks install apisix
        local ok, err = status_protoc:load(grpc_status_proto, "grpc_status.proto")
        if not ok then
            status_protoc:reset()
            pb.state(old_pb_state)
            return "failed to load grpc status protocol: " .. err
        end
        -- pb.state(old_pb_state) restores the previous state and returns the
        -- newly built one, which is cached here for later error decoding
        status_pb_state = pb.state(old_pb_state)
    end
end
--- Expose the pb state compiled from the embedded gRPC status proto.
-- Consumers (e.g. the response filter) use it to decode grpc-status-details-bin.
_M.fetch_status_pb_state = function ()
    return status_pb_state
end
--- Plugin init: start watching /protos in etcd and compile the embedded
-- gRPC status proto once. Errors are logged, not propagated, so a failure
-- here degrades the feature instead of blocking startup.
function _M.init()
    local err
    protos, err = core.config.new("/protos", {
        automatic = true,
        item_schema = core.schema.proto
    })
    if not protos then
        core.log.error("failed to create etcd instance for fetching protos: ",
                       err)
        return
    end
    if not status_pb_state then
        err = init_status_pb_state()
        if err then
            core.log.error("failed to init grpc status proto: ",
                           err)
            return
        end
    end
end
--- Release the etcd watcher created by init(), if one exists.
function _M.destroy()
    if not protos then
        return
    end
    protos:close()
end
return _M

View File

@@ -0,0 +1,72 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local util = require("apisix.plugins.grpc-transcode.util")
local core = require("apisix.core")
local pb = require("pb")
local bit = require("bit")
local ngx = ngx
local string = string
local table = table
local pcall = pcall
local tonumber = tonumber
local req_read_body = ngx.req.read_body
--- grpc-transcode request phase: turn the HTTP request into a gRPC request.
-- Encodes the request parameters with protobuf, frames them with the 5-byte
-- gRPC length prefix, and rewrites method/URI/body so the request can be
-- proxied to the gRPC upstream.
-- @param proto          compiled proto object (carries its own pb_state)
-- @param service        fully qualified gRPC service name
-- @param method         gRPC method name within the service
-- @param pb_option      protobuf options applied via util.set_options
-- @param deadline       timeout forwarded as the grpc-timeout header
-- @param default_values defaults merged into the request message
-- @return true on success; false, error message, HTTP status on failure
return function (proto, service, method, pb_option, deadline, default_values)
    core.log.info("proto: ", core.json.delay_encode(proto, true))
    local m = util.find_method(proto, service, method)
    if not m then
        return false, "Undefined service method: " .. service .. "/" .. method
                      .. " end", 503
    end
    -- the body must be read before core/ngx body accessors can fetch it
    req_read_body()
    -- swap in this proto's pb state so encoding uses the right descriptors
    local pb_old_state = pb.state(proto.pb_state)
    util.set_options(proto, pb_option)
    -- NOTE(review): map_message's error return is not checked here — a nil
    -- message falls through to pb.encode; confirm this is intended
    local map_message = util.map_message(m.input_type, default_values or {})
    local ok, encoded = pcall(pb.encode, m.input_type, map_message)
    pb.state(pb_old_state)
    if not ok or not encoded then
        return false, "failed to encode request data to protobuf", 400
    end
    local size = #encoded
    -- gRPC message frame: 1 compression-flag byte (0 = uncompressed)
    -- followed by a 4-byte big-endian message length
    local prefix = {
        string.char(0),
        string.char(bit.band(bit.rshift(size, 24), 0xFF)),
        string.char(bit.band(bit.rshift(size, 16), 0xFF)),
        string.char(bit.band(bit.rshift(size, 8), 0xFF)),
        string.char(bit.band(size, 0xFF))
    }
    local message = table.concat(prefix, "") .. encoded
    ngx.req.set_method(ngx.HTTP_POST)
    ngx.req.set_uri("/" .. service .. "/" .. method, false)
    ngx.req.set_uri_args({})
    ngx.req.set_body_data(message)
    local dl = tonumber(deadline)
    if dl~= nil and dl > 0 then
        -- the "m" suffix selects milliseconds in the grpc-timeout header
        ngx.req.set_header("grpc-timeout", dl .. "m")
    end
    return true
end

View File

@@ -0,0 +1,144 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local util = require("apisix.plugins.grpc-transcode.util")
local grpc_proto = require("apisix.plugins.grpc-transcode.proto")
local core = require("apisix.core")
local pb = require("pb")
local ngx = ngx
local string = string
local ngx_decode_base64 = ngx.decode_base64
local ipairs = ipairs
local pcall = pcall
--- Build a JSON error body from the upstream's grpc-status-details-bin trailer.
-- Decodes the base64 trailer into a grpc_status.ErrorStatus message and, when
-- status_detail_type is configured, also decodes each Any detail payload with
-- the route's own proto. The result (or an error string) is written into
-- ngx.arg[1].
-- @return an error message on failure; nil on success
local function handle_error_response(status_detail_type, proto)
    local err_msg
    local grpc_status = ngx.header["grpc-status-details-bin"]
    if grpc_status then
        grpc_status = ngx_decode_base64(grpc_status)
        if grpc_status == nil then
            err_msg = "grpc-status-details-bin is not base64 format"
            ngx.arg[1] = err_msg
            return err_msg
        end
        -- decode the outer ErrorStatus with the dedicated status pb state
        local status_pb_state = grpc_proto.fetch_status_pb_state()
        local old_pb_state = pb.state(status_pb_state)
        local ok, decoded_grpc_status = pcall(pb.decode, "grpc_status.ErrorStatus", grpc_status)
        pb.state(old_pb_state)
        if not ok then
            err_msg = "failed to call pb.decode to decode grpc-status-details-bin"
            ngx.arg[1] = err_msg
            return err_msg .. ", err: " .. decoded_grpc_status
        end
        if not decoded_grpc_status then
            err_msg = "failed to decode grpc-status-details-bin"
            ngx.arg[1] = err_msg
            return err_msg
        end
        local details = decoded_grpc_status.details
        if status_detail_type and details then
            -- decode each Any.value using the route's proto descriptors
            local decoded_details = {}
            for _, detail in ipairs(details) do
                local pb_old_state = pb.state(proto.pb_state)
                local ok, err_or_value = pcall(pb.decode, status_detail_type, detail.value)
                pb.state(pb_old_state)
                if not ok then
                    err_msg = "failed to call pb.decode to decode details in "
                              .. "grpc-status-details-bin"
                    ngx.arg[1] = err_msg
                    return err_msg .. ", err: " .. err_or_value
                end
                if not err_or_value then
                    err_msg = "failed to decode details in grpc-status-details-bin"
                    ngx.arg[1] = err_msg
                    return err_msg
                end
                core.table.insert(decoded_details, err_or_value)
            end
            decoded_grpc_status.details = decoded_details
        end
        local resp_body = {error = decoded_grpc_status}
        local response, err = core.json.encode(resp_body)
        if not response then
            err_msg = "failed to json_encode response body"
            ngx.arg[1] = err_msg
            return err_msg .. ", error: " .. err
        end
        ngx.arg[1] = response
    end
end
--- grpc-transcode response body filter: buffer the gRPC response, strip the
-- 5-byte length-prefix frame, decode it with protobuf and replace the body
-- with its JSON representation.
-- @return nil while buffering or on success; an error message on failure
return function(ctx, proto, service, method, pb_option, show_status_in_body, status_detail_type)
    local buffer = core.response.hold_body_chunk(ctx)
    if not buffer then
        -- not the last chunk yet; keep accumulating
        return nil
    end
    -- handle error response after the last response chunk
    if ngx.status >= 300 and show_status_in_body then
        return handle_error_response(status_detail_type, proto)
    end
    -- when body has already been read by other plugin
    -- the buffer is an empty string
    if buffer == "" and ctx.resp_body then
        buffer = ctx.resp_body
    end
    local m = util.find_method(proto, service, method)
    if not m then
        return false, "2.Undefined service method: " .. service .. "/" .. method
                      .. " end."
    end
    if not ngx.req.get_headers()["X-Grpc-Web"] then
        -- plain gRPC: drop the 1-byte flag + 4-byte length frame prefix
        buffer = string.sub(buffer, 6)
    end
    local pb_old_state = pb.state(proto.pb_state)
    util.set_options(proto, pb_option)
    local err_msg
    local decoded = pb.decode(m.output_type, buffer)
    pb.state(pb_old_state)
    if not decoded then
        err_msg = "failed to decode response data by protobuf"
        ngx.arg[1] = err_msg
        return err_msg
    end
    local response, err = core.json.encode(decoded)
    if not response then
        err_msg = "failed to json_encode response body"
        ngx.arg[1] = err_msg
        return err_msg .. ", err: " .. err
    end
    ngx.arg[1] = response
    return nil
end

View File

@@ -0,0 +1,202 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local proto_fake_file = require("apisix.plugins.grpc-transcode.proto").proto_fake_file
local json = core.json
local pb = require("pb")
local ngx = ngx
local string = string
local table = table
local ipairs = ipairs
local pairs = pairs
local tonumber = tonumber
local type = type
local _M = {version = 0.1}
--- Look up the compiled descriptor for service/method inside a loaded proto.
-- Returns the method descriptor table, or nil (after logging the reason)
-- when the proto, service or method cannot be found.
function _M.find_method(proto, service, method)
    local compiled = proto[proto_fake_file]
    if type(compiled) ~= "table" then
        core.log.error("compiled proto not found")
        return nil
    end
    local svc_index = compiled.index[service]
    if type(svc_index) ~= "table" then
        core.log.error("compiled proto service not found")
        return nil
    end
    local method_def = svc_index[method]
    if not method_def then
        core.log.error("compiled proto method not found")
        return nil
    end
    return method_def
end
--- Apply protobuf serialization options for this proto, skipping the work
-- when the same (sorted) option list was already applied for this proto.
-- NOTE: pb.option() mutates global pb state, so callers swap in the proto's
-- pb state before calling this.
function _M.set_options(proto, options)
    local cur_opts = proto.options
    if cur_opts then
        if cur_opts == options then
            -- same route
            return
        end
        local same = true
        table.sort(options)
        -- element-wise comparison; assumes equal-length lists represent the
        -- same configuration — TODO confirm shorter/longer lists can't alias
        for i, v in ipairs(options) do
            if cur_opts[i] ~= v then
                same = false
                break
            end
        end
        if same then
            -- Routes have the same configuration, usually the default one.
            -- As this is a small optimization, we don't care about routes have different
            -- configuration but have the same effect eventually.
            return
        end
    else
        table.sort(options)
    end
    for _, opt in ipairs(options) do
        pb.option(opt)
    end
    proto.options = options
end
--- Collect the request parameters as a Lua table.
-- A JSON body (for POST/PUT/PATCH) takes precedence; otherwise POST form
-- arguments or, failing that, the query string are used.
local function get_request_table()
    local method = ngx.req.get_method()
    local content_type = ngx.req.get_headers()["Content-Type"] or ""
    local is_json = string.find(content_type, "application/json", 1, true)
    local may_have_body = method == "POST" or method == "PUT" or method == "PATCH"
    if is_json and may_have_body then
        local req_body = core.request.get_body()
        if req_body then
            local data = json.decode(req_body)
            if data then
                return data
            end
        end
    end
    if method == "POST" then
        return ngx.req.get_post_args()
    end
    return ngx.req.get_uri_args()
end
--- Fetch a scalar field from the request table, coercing integer kinds.
-- int64 values stay as strings to avoid precision loss; other int kinds
-- are converted with tonumber.
local function get_from_request(request_table, name, kind)
    if not request_table then
        return nil
    end
    local value = request_table[name]
    if value and kind:sub(1, 3) == "int" and kind ~= "int64" then
        return tonumber(value)
    end
    return value
end
--- Recursively build a protobuf message table from the request parameters.
-- @param field          fully qualified message type name
-- @param default_values defaults applied when a field is absent
-- @param request_table  source table; read from the request when nil
-- @param real_key       when set, store the scalar value under this key
--                       instead of the field name (used for map entries)
-- @return message table, or nil plus an error message
function _M.map_message(field, default_values, request_table, real_key)
    if not pb.type(field) then
        return nil, "Field " .. field .. " is not defined"
    end
    local request = {}
    local sub, err
    if not request_table then
        request_table = get_request_table()
    end
    for name, _, field_type in pb.fields(field) do
        local _, _, ty = pb.type(field_type)
        -- a leading "." marks a message/map type that needs recursion;
        -- enums are scalar even though their type name is qualified
        if ty ~= "enum" and field_type:sub(1, 1) == "." then
            if request_table[name] == nil then
                sub = default_values and default_values[name]
            elseif core.table.isarray(request_table[name]) then
                -- repeated field: map each element, recursing for messages
                local sub_array = core.table.new(#request_table[name], 0)
                for i, value in ipairs(request_table[name]) do
                    local sub_array_obj
                    if type(value) == "table" then
                        sub_array_obj, err = _M.map_message(field_type,
                                    default_values and default_values[name], value)
                        if err then
                            return nil, err
                        end
                    else
                        sub_array_obj = value
                    end
                    sub_array[i] = sub_array_obj
                end
                sub = sub_array
            else
                if ty == "map" then
                    -- map field: convert entry-by-entry, passing the map key
                    -- through real_key so the value lands under it
                    for k, v in pairs(request_table[name]) do
                        local tbl, err = _M.map_message(field_type,
                                default_values and default_values[name],
                                request_table[name], k)
                        if err then
                            return nil, err
                        end
                        if not sub then
                            sub = {}
                        end
                        sub[k] = tbl[k]
                    end
                else
                    sub, err = _M.map_message(field_type,
                                default_values and default_values[name],
                                request_table[name])
                    if err then
                        return nil, err
                    end
                end
            end
            request[name] = sub
        else
            if real_key then
                name = real_key
            end
            request[name] = get_from_request(request_table, name, field_type)
                                or (default_values and default_values[name])
        end
    end
    return request
end
return _M

View File

@@ -0,0 +1,228 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
local ngx = ngx
local ngx_arg = ngx.arg
local core = require("apisix.core")
local req_set_uri = ngx.req.set_uri
local req_set_body_data = ngx.req.set_body_data
local decode_base64 = ngx.decode_base64
local encode_base64 = ngx.encode_base64
local bit = require("bit")
local string = string
local ALLOW_METHOD_OPTIONS = "OPTIONS"
local ALLOW_METHOD_POST = "POST"
local CONTENT_ENCODING_BASE64 = "base64"
local CONTENT_ENCODING_BINARY = "binary"
local DEFAULT_CORS_ALLOW_ORIGIN = "*"
local DEFAULT_CORS_ALLOW_METHODS = ALLOW_METHOD_POST
local DEFAULT_CORS_ALLOW_HEADERS = "content-type,x-grpc-web,x-user-agent"
local DEFAULT_CORS_EXPOSE_HEADERS = "grpc-message,grpc-status"
local DEFAULT_PROXY_CONTENT_TYPE = "application/grpc"
local plugin_name = "grpc-web"
local schema = {
type = "object",
properties = {
cors_allow_headers = {
description =
"multiple header use ',' to split. default: content-type,x-grpc-web,x-user-agent.",
type = "string",
default = DEFAULT_CORS_ALLOW_HEADERS
}
}
}
local grpc_web_content_encoding = {
["application/grpc-web"] = CONTENT_ENCODING_BINARY,
["application/grpc-web-text"] = CONTENT_ENCODING_BASE64,
["application/grpc-web+proto"] = CONTENT_ENCODING_BINARY,
["application/grpc-web-text+proto"] = CONTENT_ENCODING_BASE64,
}
local _M = {
version = 0.1,
priority = 505,
name = plugin_name,
schema = schema,
}
--- Validate the grpc-web plugin configuration against its schema.
_M.check_schema = function (conf)
    return core.schema.check(schema, conf)
end
--- Short-circuit the request: flag the context so body_filter is skipped,
-- then hand the desired HTTP status back to the phase handler.
local exit = function (ctx, status)
    ctx.grpc_web_skip_body_filter = true
    return status
end
--- Build a gRPC-Web trailer chunk.
-- grpc-web trailer format reference:
-- envoyproxy/envoy/source/extensions/filters/http/grpc_web/grpc_web_filter.cc
--
-- Layout: 1 byte 0x80 flag, 4 bytes big-endian trailer length, then the
-- trailer text itself. The grpc status/message come from nginx's
-- upstream_trailer_* variables (available since NGINX 1.13.10):
-- https://nginx.org/en/docs/http/ngx_http_upstream_module.html#var_upstream_trailer_
--
-- @param grpc_status number grpc status code
-- @param grpc_message string grpc message
-- @return string raw gRPC-Web trailer chunk
local build_trailer = function (grpc_status, grpc_message)
    local payload = "grpc-status:" .. grpc_status .. "\r\n"
                    .. "grpc-message:" .. (grpc_message or "") .. "\r\n"
    local len = #payload
    -- 0x80 flag byte followed by the length as a 32-bit big-endian integer
    local header = string.char(
        0x80,
        bit.band(bit.rshift(len, 24), 0xff),
        bit.band(bit.rshift(len, 16), 0xff),
        bit.band(bit.rshift(len, 8), 0xff),
        bit.band(len, 0xff)
    )
    return header .. payload
end
--- Rewrite a gRPC-Web request into a plain gRPC request.
-- Validates method and Content-Type, resolves the upstream gRPC path from
-- the route's prefix-match remainder, base64-decodes text-encoded bodies
-- and switches the Content-Type to application/grpc.
function _M.access(conf, ctx)
    -- set context variable mime
    -- When processing non gRPC Web requests, `mime` can be obtained in the context
    -- and set to the `Content-Type` of the response
    ctx.grpc_web_mime = core.request.header(ctx, "Content-Type")
    local method = core.request.get_method()
    if method == ALLOW_METHOD_OPTIONS then
        -- CORS preflight: answer directly with 204
        return exit(ctx, 204)
    end
    if method ~= ALLOW_METHOD_POST then
        -- https://github.com/grpc/grpc-web/blob/master/doc/browser-features.md#cors-support
        core.log.error("request method: `", method, "` invalid")
        return exit(ctx, 405)
    end
    local encoding = grpc_web_content_encoding[ctx.grpc_web_mime]
    if not encoding then
        core.log.error("request Content-Type: `", ctx.grpc_web_mime, "` invalid")
        return exit(ctx, 400)
    end
    -- set context variable encoding method
    ctx.grpc_web_encoding = encoding
    -- set grpc path
    if not (ctx.curr_req_matched and ctx.curr_req_matched[":ext"]) then
        core.log.error("routing configuration error, grpc-web plugin only supports ",
                       "`prefix matching` pattern routing")
        return exit(ctx, 400)
    end
    local path = ctx.curr_req_matched[":ext"]
    if path:byte(1) ~= core.string.byte("/") then
        path = "/" .. path
    end
    req_set_uri(path)
    -- set grpc body
    local body, err = core.request.get_body()
    if err or not body then
        core.log.error("failed to read request body, err: ", err)
        return exit(ctx, 400)
    end
    if encoding == CONTENT_ENCODING_BASE64 then
        body = decode_base64(body)
        if not body then
            core.log.error("failed to decode request body")
            return exit(ctx, 400)
        end
    end
    -- set grpc content-type
    core.request.set_header(ctx, "Content-Type", DEFAULT_PROXY_CONTENT_TYPE)
    -- set grpc body
    req_set_body_data(body)
end
--- Attach CORS headers for the gRPC-Web handshake and, when the body will
-- be transcoded, restore the client's original Content-Type.
function _M.header_filter(conf, ctx)
    local set_header = core.response.set_header
    if core.request.get_method() == ALLOW_METHOD_OPTIONS then
        set_header("Access-Control-Allow-Methods", DEFAULT_CORS_ALLOW_METHODS)
        set_header("Access-Control-Allow-Headers", conf.cors_allow_headers)
    end
    if not ctx.cors_allow_origins then
        -- no origin decided elsewhere (e.g. cors plugin); fall back to "*"
        set_header("Access-Control-Allow-Origin", DEFAULT_CORS_ALLOW_ORIGIN)
    end
    set_header("Access-Control-Expose-Headers", DEFAULT_CORS_EXPOSE_HEADERS)
    if not ctx.grpc_web_skip_body_filter then
        set_header("Content-Type", ctx.grpc_web_mime)
        set_header("Content-Length", nil)
    end
end
--- Transform the upstream gRPC response back into gRPC-Web format:
-- re-encode chunks as base64 in text mode and, on EOF, append the
-- gRPC-Web trailer frame carrying grpc-status/grpc-message.
function _M.body_filter(conf, ctx)
    if ctx.grpc_web_skip_body_filter then
        return
    end
    -- If the MIME extension type description of the gRPC-Web standard is not obtained,
    -- indicating that the request is not based on the gRPC Web specification,
    -- the processing of the request body will be ignored
    -- https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-WEB.md
    -- https://github.com/grpc/grpc-web/blob/master/doc/browser-features.md#cors-support
    if not ctx.grpc_web_mime then
        return
    end
    if ctx.grpc_web_encoding == CONTENT_ENCODING_BASE64 then
        local chunk = ngx_arg[1]
        chunk = encode_base64(chunk)
        ngx_arg[1] = chunk
    end
    if ngx_arg[2] then -- if eof
        local status = ctx.var.upstream_trailer_grpc_status
        local message = ctx.var.upstream_trailer_grpc_message
        -- When the response body completes and still does not receive the grpc status
        -- fall back to code 2 (UNKNOWN) with an explanatory message
        local resp_ok = status ~= nil and status ~= ""
        local trailer_buf = build_trailer(
            resp_ok and status or 2,
            resp_ok and message or "upstream grpc status not received"
        )
        if ctx.grpc_web_encoding == CONTENT_ENCODING_BASE64 then
            trailer_buf = encode_base64(trailer_buf)
        end
        ngx_arg[1] = ngx_arg[1] .. trailer_buf
    end
end
return _M

View File

@@ -0,0 +1,170 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local is_apisix_or, response = pcall(require, "resty.apisix.response")
local ngx_header = ngx.header
local req_http_version = ngx.req.http_version
local str_sub = string.sub
local ipairs = ipairs
local tonumber = tonumber
local type = type
local schema = {
type = "object",
properties = {
types = {
anyOf = {
{
type = "array",
minItems = 1,
items = {
type = "string",
minLength = 1,
},
},
{
enum = {"*"}
}
},
default = {"text/html"}
},
min_length = {
type = "integer",
minimum = 1,
default = 20,
},
comp_level = {
type = "integer",
minimum = 1,
maximum = 9,
default = 1,
},
http_version = {
enum = {1.1, 1.0},
default = 1.1,
},
buffers = {
type = "object",
properties = {
number = {
type = "integer",
minimum = 1,
default = 32,
},
size = {
type = "integer",
minimum = 1,
default = 4096,
}
},
default = {
number = 32,
size = 4096,
}
},
vary = {
type = "boolean",
}
},
}
local plugin_name = "gzip"
local _M = {
version = 0.1,
priority = 995,
name = plugin_name,
schema = schema,
}
--- Validate the gzip plugin configuration against its schema.
_M.check_schema = function (conf)
    return core.schema.check(schema, conf)
end
--- Decide per-response whether to enable gzip and configure it via
-- APISIX-Runtime. Skips compression when the Content-Type does not match
-- the allowlist, the body is smaller than min_length, or the request's
-- HTTP version is older than configured.
function _M.header_filter(conf, ctx)
    if not is_apisix_or then
        -- resty.apisix.response is only shipped with APISIX-Runtime builds
        core.log.error("need to build APISIX-Runtime to support setting gzip")
        return 501
    end
    local types = conf.types
    local content_type = ngx_header["Content-Type"]
    if not content_type then
        -- Like Nginx, don't gzip if Content-Type is missing
        return
    end
    if type(types) == "table" then
        -- a table is an allowlist; types == "*" (a string) compresses all
        local matched = false
        local from = core.string.find(content_type, ";")
        if from then
            -- ignore parameters such as "; charset=utf-8"
            content_type = str_sub(content_type, 1, from - 1)
        end
        for _, ty in ipairs(types) do
            if content_type == ty then
                matched = true
                break
            end
        end
        if not matched then
            return
        end
    end
    local content_length = tonumber(ngx_header["Content-Length"])
    if content_length then
        local min_length = conf.min_length
        if content_length < min_length then
            return
        end
        -- Like Nginx, don't check min_length if Content-Length is missing
    end
    local http_version = req_http_version()
    if http_version < conf.http_version then
        return
    end
    local buffers = conf.buffers
    core.log.info("set gzip with buffers: ", buffers.number, " ", buffers.size,
                  ", level: ", conf.comp_level)
    local ok, err = response.set_gzip({
        buffer_num = buffers.number,
        buffer_size = buffers.size,
        compress_level = conf.comp_level,
    })
    if not ok then
        core.log.error("failed to set gzip: ", err)
        return
    end
    if conf.vary then
        core.response.add_header("Vary", "Accept-Encoding")
    end
end
return _M

View File

@@ -0,0 +1,372 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local ngx = ngx
local abs = math.abs
local ngx_time = ngx.time
local ngx_re = require("ngx.re")
local ipairs = ipairs
local hmac_sha1 = ngx.hmac_sha1
local core = require("apisix.core")
local hmac = require("resty.hmac")
local consumer = require("apisix.consumer")
local ngx_decode_base64 = ngx.decode_base64
local ngx_encode_base64 = ngx.encode_base64
local plugin_name = "hmac-auth"
local ALLOWED_ALGORITHMS = {"hmac-sha1", "hmac-sha256", "hmac-sha512"}
local resty_sha256 = require("resty.sha256")
local schema_def = require("apisix.schema_def")
local auth_utils = require("apisix.utils.auth")
local schema = {
type = "object",
title = "work with route or service object",
properties = {
allowed_algorithms = {
type = "array",
minItems = 1,
items = {
type = "string",
enum = ALLOWED_ALGORITHMS
},
default = ALLOWED_ALGORITHMS,
},
clock_skew = {
type = "integer",
default = 300,
minimum = 1
},
signed_headers = {
type = "array",
items = {
type = "string",
minLength = 1,
maxLength = 50,
}
},
validate_request_body = {
type = "boolean",
title = "A boolean value telling the plugin to enable body validation",
default = false,
},
hide_credentials = {type = "boolean", default = false},
anonymous_consumer = schema_def.anonymous_consumer_schema,
},
}
local consumer_schema = {
type = "object",
title = "work with consumer object",
properties = {
key_id = {type = "string", minLength = 1, maxLength = 256},
secret_key = {type = "string", minLength = 1, maxLength = 256},
},
encrypt_fields = {"secret_key"},
required = {"key_id", "secret_key"},
}
local _M = {
version = 0.1,
priority = 2530,
type = 'auth',
name = plugin_name,
schema = schema,
consumer_schema = consumer_schema
}
local hmac_funcs = {
["hmac-sha1"] = function(secret_key, message)
return hmac_sha1(secret_key, message)
end,
["hmac-sha256"] = function(secret_key, message)
return hmac:new(secret_key, hmac.ALGOS.SHA256):final(message)
end,
["hmac-sha512"] = function(secret_key, message)
return hmac:new(secret_key, hmac.ALGOS.SHA512):final(message)
end,
}
--- Convert an array of strings into a set-like table (value -> true)
-- for O(1) membership checks.
local function array_to_map(arr)
    local set = core.table.new(0, #arr)
    for i = 1, #arr do
        set[arr[i]] = true
    end
    return set
end
--- Validate either the plugin (route/service) config or, when schema_type
-- indicates a consumer, the consumer credential config.
function _M.check_schema(conf, schema_type)
    core.log.info("input conf: ", core.json.delay_encode(conf))
    local target_schema = schema
    if schema_type == core.schema.TYPE_CONSUMER then
        target_schema = consumer_schema
    end
    return core.schema.check(target_schema, conf)
end
--- Look up the consumer bound to the given HMAC key id.
-- @param key_id string key id extracted from the Authorization header
-- @return consumer object on success; nil plus an error message on failure
local function get_consumer(key_id)
    if not key_id then
        return nil, "missing key_id"
    end
    local cur_consumer, _, err = consumer.find_consumer(plugin_name, "key_id", key_id)
    if not cur_consumer then
        return nil, err or "Invalid key_id"
    end
    -- fix: log the consumer that was found; the old code logged the required
    -- "apisix.consumer" module table instead of the lookup result
    core.log.info("consumer: ", core.json.delay_encode(cur_consumer, true))
    return cur_consumer
end
--- Rebuild the HTTP signature string and HMAC it with the consumer's secret.
-- The signing string is keyId followed by each signed header rendered as
-- "name: value" ("@request-target" expands to "<METHOD> <uri>"), joined by
-- newlines with a trailing newline.
-- @return raw (binary) HMAC digest computed with params.algorithm
local function generate_signature(ctx, secret_key, params)
    local uri = ctx.var.request_uri
    local request_method = core.request.get_method()
    if uri == "" then
        uri = "/"
    end
    local signing_string_items = {
        params.keyId,
    }
    if params.headers then
        for _, h in ipairs(params.headers) do
            local canonical_header = core.request.header(ctx, h)
            if not canonical_header then
                if h == "@request-target" then
                    local request_target = request_method .. " " .. uri
                    core.table.insert(signing_string_items, request_target)
                    core.log.info("canonical_header name:", core.json.delay_encode(h))
                    core.log.info("canonical_header value: ",
                                  core.json.delay_encode(request_target))
                end
                -- NOTE(review): other headers that are absent from the request
                -- are silently skipped — confirm this matches the signing contract
            else
                core.table.insert(signing_string_items,
                                  h .. ": " .. canonical_header)
                core.log.info("canonical_header name:", core.json.delay_encode(h))
                core.log.info("canonical_header value: ",
                              core.json.delay_encode(canonical_header))
            end
        end
    end
    local signing_string = core.table.concat(signing_string_items, "\n") .. "\n"
    return hmac_funcs[params.algorithm](secret_key, signing_string)
end
--- Return the raw (binary) SHA-256 digest of the given string.
local sha256 = function (key)
    local digester = resty_sha256:new()
    digester:update(key)
    return digester:final()
end
--- Validate an incoming HMAC-signed request.
-- Checks, in order: required signature fields, algorithm allowlist, clock
-- skew (Date header), required signed headers, the signature itself, and
-- optionally the SHA-256 Digest of the request body.
-- @return the matched consumer on success; nil plus an error message otherwise
local function validate(ctx, conf, params)
    if not params then
        return nil
    end
    if not params.keyId or not params.signature then
        return nil, "keyId or signature missing"
    end
    if not params.algorithm then
        return nil, "algorithm missing"
    end
    -- NOTE: this local shadows the required "apisix.consumer" module for the
    -- rest of this function; no module functions are called after this point
    local consumer, err = get_consumer(params.keyId)
    if err then
        return nil, err
    end
    local consumer_conf = consumer.auth_conf
    local found_algorithm = false
    -- check supported algorithm used
    if not conf.allowed_algorithms then
        conf.allowed_algorithms = ALLOWED_ALGORITHMS
    end
    for _, algo in ipairs(conf.allowed_algorithms) do
        if algo == params.algorithm then
            found_algorithm = true
            break
        end
    end
    if not found_algorithm then
        return nil, "Invalid algorithm"
    end
    core.log.info("clock_skew: ", conf.clock_skew)
    if conf.clock_skew and conf.clock_skew > 0 then
        if not params.date then
            return nil, "Date header missing. failed to validate clock skew"
        end
        local time = ngx.parse_http_time(params.date)
        core.log.info("params.date: ", params.date, " time: ", time)
        if not time then
            return nil, "Invalid GMT format time"
        end
        -- reject requests whose Date deviates from server time by more
        -- than the configured skew (in either direction)
        local diff = abs(ngx_time() - time)
        if diff > conf.clock_skew then
            return nil, "Clock skew exceeded"
        end
    end
    -- validate headers
    -- All headers passed in route conf.signed_headers must be used in signing(params.headers)
    if conf.signed_headers and #conf.signed_headers >= 1 then
        if not params.headers then
            return nil, "headers missing"
        end
        local params_headers_map = array_to_map(params.headers)
        if params_headers_map then
            for _, header in ipairs(conf.signed_headers) do
                if not params_headers_map[header] then
                    return nil, [[expected header "]] .. header .. [[" missing in signing]]
                end
            end
        end
    end
    local secret_key = consumer_conf and consumer_conf.secret_key
    local request_signature = ngx_decode_base64(params.signature)
    local generated_signature = generate_signature(ctx, secret_key, params)
    if request_signature ~= generated_signature then
        return nil, "Invalid signature"
    end
    local validate_request_body = conf.validate_request_body
    if validate_request_body then
        -- verify the body digest against the client's Digest header,
        -- which must have the form "SHA-256=<base64 digest>"
        local digest_header = params.body_digest
        if not digest_header then
            return nil, "Invalid digest"
        end
        local req_body, err = core.request.get_body()
        if err then
            return nil, err
        end
        req_body = req_body or ""
        local digest_created = "SHA-256" .. "=" ..
                ngx_encode_base64(sha256(req_body))
        if digest_created ~= digest_header then
            return nil, "Invalid digest"
        end
    end
    return consumer
end
--- Parse the "Signature ..." Authorization header into a parameter table.
-- Extracts keyId/algorithm/signature, splits the space-separated headers
-- list, and captures the Date and Digest request headers for validation.
-- @return params table on success; nil plus an error message on failure
local function retrieve_hmac_fields(ctx)
    local hmac_params = {}
    local auth_string = core.request.header(ctx, "Authorization")
    if not auth_string then
        return nil, "missing Authorization header"
    end
    if not core.string.has_prefix(auth_string, "Signature") then
        return nil, "Authorization header does not start with 'Signature'"
    end
    -- skip the 9-char "Signature" prefix and iterate comma-separated fields
    local signature_fields = auth_string:sub(10):gmatch('[^,]+')
    for field in signature_fields do
        -- each field has the shape: key="value"
        local key, value = field:match('%s*(%w+)="(.-)"')
        if key and value then
            if key == "keyId" or key == "algorithm" or key == "signature" then
                hmac_params[key] = value
            elseif key == "headers" then
                hmac_params.headers = ngx_re.split(value, " ")
            end
        end
    end
    -- will be required to check clock skew
    if core.request.header(ctx, "Date") then
        hmac_params.date = core.request.header(ctx, "Date")
    end
    if core.request.header(ctx, "Digest") then
        hmac_params.body_digest = core.request.header(ctx, "Digest")
    end
    return hmac_params
end
--- Resolve the consumer for this request from its HMAC signature.
-- @return consumer, consumers_conf, err — consumer is nil on failure
local function find_consumer(conf, ctx)
    local params,err = retrieve_hmac_fields(ctx)
    if err then
        -- only warn when not wrapped by the multi-auth plugin, which expects
        -- failures while probing each auth scheme in turn
        if not auth_utils.is_running_under_multi_auth(ctx) then
            core.log.warn("client request can't be validated: ", err)
        end
        return nil, nil, "client request can't be validated: " .. err
    end
    local validated_consumer, err = validate(ctx, conf, params)
    if not validated_consumer then
        err = "client request can't be validated: " .. (err or "Invalid signature")
        if auth_utils.is_running_under_multi_auth(ctx) then
            return nil, nil, err
        end
        core.log.warn(err)
        -- generic message: do not leak validation details to the client
        return nil, nil, "client request can't be validated"
    end
    local consumers_conf = consumer.consumers_conf(plugin_name)
    return validated_consumer, consumers_conf, err
end
--- rewrite phase: authenticate the request via its HMAC signature, falling
-- back to the configured anonymous consumer when validation fails, then
-- attach the resolved consumer to the request context.
function _M.rewrite(conf, ctx)
    local cur_consumer, consumers_conf, err = find_consumer(conf, ctx)
    if not cur_consumer then
        if not conf.anonymous_consumer then
            return 401, { message = err }
        end
        cur_consumer, consumers_conf, err = consumer.get_anonymous_consumer(conf.anonymous_consumer)
        if not cur_consumer then
            if auth_utils.is_running_under_multi_auth(ctx) then
                return 401, err
            end
            core.log.error(err)
            return 401, { message = "Invalid user authorization" }
        end
    end
    if conf.hide_credentials then
        -- strip the signature so it is not forwarded upstream
        core.request.set_header("Authorization", nil)
    end
    consumer.attach_consumer(ctx, cur_consumer, consumers_conf)
end
return _M

View File

@@ -0,0 +1,262 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local require = require
local core = require("apisix.core")
local pairs = pairs
local str_format = string.format
local bit = require("bit")
local rshift = bit.rshift
local band = bit.band
local char = string.char
local tostring = tostring
local ngx = ngx
local type = type
local plugin_name = "http-dubbo"
-- configuration schema for the http-dubbo plugin
local schema = {
    type = "object",
    required = { "service_name", "method" },
    properties = {
        service_name = { type = "string", minLength = 1 },
        service_version = {
            type = "string",
            pattern = [[^\d+\.\d+\.\d+]],
            default = "0.0.0",
        },
        method = { type = "string", minLength = 1 },
        params_type_desc = { type = "string", default = "" },
        serialization_header_key = { type = "string" },
        serialized = { type = "boolean", default = false },
        -- socket timeouts, in milliseconds
        connect_timeout = { type = "number", default = 6000 },
        read_timeout = { type = "number", default = 6000 },
        send_timeout = { type = "number", default = 6000 },
    },
}

local _M = {
    version = 0.1,
    priority = 504,
    name = plugin_name,
    schema = schema,
}
-- Validate a plugin configuration against the schema above.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end
-- Serialize a 32-bit integer into 4 big-endian bytes.
local function str_int32(int)
    local b1 = band(rshift(int, 24), 0xff)
    local b2 = band(rshift(int, 16), 0xff)
    local b3 = band(rshift(int, 8), 0xff)
    local b4 = band(int, 0xff)
    return char(b1, b2, b3, b4)
end
-- Decode a 16-byte dubbo protocol response header.
-- Returns nil when fewer than 16 bytes are available.
local function parse_dubbo_header(header)
    -- reject short reads: all 16 header bytes must be present
    for i = 1, 16 do
        if not header:byte(i) then
            return nil
        end
    end

    local magic_number = str_format("%04x", header:byte(1) * 256 + header:byte(2))
    local message_flag = header:byte(3)
    local status = header:byte(4)

    -- bytes 5-12: request id, big-endian
    local request_id = 0
    for i = 5, 12 do
        request_id = request_id * 256 + header:byte(i)
    end

    -- bytes 13-16: payload length, big-endian (256^3 = 16777216, 256^2 = 65536)
    local data_length = header:byte(13) * 16777216
                        + header:byte(14) * 65536
                        + header:byte(15) * 256
                        + header:byte(16)

    -- individual flag bits of byte 3, reported as 1/0
    local function flag_bit(pos)
        return bit.band(bit.rshift(message_flag, pos), 0x01) == 1 and 1 or 0
    end

    return {
        magic_number = magic_number,
        message_flag = message_flag,
        is_request = flag_bit(7),
        is_two_way = flag_bit(6),
        is_event = flag_bit(5),
        status = status,
        request_id = request_id,
        data_length = data_length
    }
end
-- Characters that must be escaped inside a JSON string literal, mapped to
-- their escaped spelling.
local json_escape_map = {
    ["\\"] = "\\\\",
    ["\n"] = "\\n",
    ["\t"] = "\\t",
    ["\r"] = "\\r",
    ["\b"] = "\\b",
    ["\f"] = "\\f",
    ["\""] = "\\\"",
}

-- Encode a Lua string as a quoted JSON string literal.
-- The previous implementation concatenated one character at a time, which
-- is O(n^2) in Lua; a single gsub with a lookup table performs the same
-- escaping in one pass. Bytes outside the map (including other control
-- characters, matching the previous behavior) are copied unchanged.
local function string_to_json_string(str)
    return "\"" .. str:gsub("[\\\n\t\r\b\f\"]", json_escape_map) .. "\""
end
-- Build the dubbo request frame as a table of byte-string fragments
-- (suitable for cosocket send). Arguments are newline-terminated JSON
-- values; the payload length is computed over all body fragments.
local function get_dubbo_request(conf, ctx)
    -- use dubbo and fastjson
    -- magic bytes 0xdabb, flags 0xc6, status 0x00
    local first_byte4 = "\xda\xbb\xc6\x00"
    -- fixed request id of 1 (8 bytes, big-endian)
    local requestId = "\x00\x00\x00\x00\x00\x00\x00\x01"
    local version = "\"2.0.2\"\n"
    local service = "\"" .. conf.service_name .. "\"" .. "\n"
    local service_version = "\"" .. conf.service_version .. "\"" .. "\n"
    local method_name = "\"" .. conf.method .. "\"" .. "\n"
    local params_desc = "\"" .. conf.params_type_desc .. "\"" .. "\n"
    local params = ""
    -- a per-request header may override the configured `serialized` flag
    local serialized = conf.serialized
    if conf.serialization_header_key then
        local serialization_header = core.request.header(ctx, conf.serialization_header_key)
        serialized = serialization_header == "true"
    end
    if serialized then
        -- body is already in wire format; only ensure a trailing newline
        params = core.request.get_body()
        if params then
            local end_of_params = core.string.sub(params, -1)
            if end_of_params ~= "\n" then
                params = params .. "\n"
            end
        end
    else
        -- body is a JSON array/object; re-encode each element as one
        -- newline-terminated JSON argument
        local body_data = core.request.get_body()
        if body_data then
            local lua_object = core.json.decode(body_data);
            for _, v in pairs(lua_object) do
                local pt = type(v)
                if pt == "nil" then
                    params = params .. "null" .. "\n"
                elseif pt == "string" then
                    params = params .. string_to_json_string(v) .. "\n"
                elseif pt == "number" then
                    params = params .. tostring(v) .. "\n"
                else
                    params = params .. core.json.encode(v) .. "\n"
                end
            end
        end
    end
    -- empty attachments map
    local attachments = "{}\n"
    -- get_body() may have returned nil in the serialized branch
    if params == nil then
        params = ""
    end
    local payload = #version + #service + #service_version
            + #method_name + #params_desc + #params + #attachments
    return {
        first_byte4,
        requestId,
        str_int32(payload),
        version,
        service,
        service_version,
        method_name,
        params_desc,
        params,
        attachments
    }
end
-- before_proxy: speak the dubbo binary protocol to the picked upstream
-- over a raw TCP socket and translate the reply into an HTTP response.
--
-- Returns:
--   200 [, body]  on a successful dubbo call
--   502           when the upstream connection fails
--   500           on send/protocol failures
function _M.before_proxy(conf, ctx)
    local sock = ngx.socket.tcp()
    sock:settimeouts(conf.connect_timeout, conf.send_timeout, conf.read_timeout)
    local ok, err = sock:connect(ctx.picked_server.host, ctx.picked_server.port)
    if not ok then
        sock:close()
        core.log.error("failed to connect to upstream ", err)
        return 502
    end

    local request = get_dubbo_request(conf, ctx)
    local bytes, send_err = sock:send(request)
    -- bugfix: sock:send() returns nil on failure, so the old `bytes > 0`
    -- comparison raised an error instead of failing gracefully
    if not bytes then
        core.log.error("failed to send dubbo request to upstream: ", send_err)
        sock:close()
        return 500
    end

    if bytes > 0 then
        local header = sock:receiveany(16)
        if header then
            local header_info = parse_dubbo_header(header)
            -- 20 is the dubbo OK status
            if header_info and header_info.status == 20 then
                local readline = sock:receiveuntil("\n")
                -- first body line carries the dubbo response-type marker
                local body_status = readline()
                if body_status then
                    local response_status = core.string.sub(body_status, 1, 1)
                    if response_status == "2" or response_status == "5" then
                        -- response carries no value
                        sock:close()
                        return 200
                    elseif response_status == "1" or response_status == "4" then
                        -- response value is on the next line
                        local body = readline()
                        sock:close()
                        return 200, body
                    end
                end
            end
        end
    end

    sock:close()
    return 500
end

return _M

View File

@@ -0,0 +1,223 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local bp_manager_mod = require("apisix.utils.batch-processor-manager")
local log_util = require("apisix.utils.log-util")
local core = require("apisix.core")
local http = require("resty.http")
local url = require("net.url")
local tostring = tostring
local ipairs = ipairs
local plugin_name = "http-logger"
local batch_processor_manager = bp_manager_mod.new("http logger")
-- configuration schema for the http-logger plugin
local schema = {
    type = "object",
    required = {"uri"},
    properties = {
        uri = core.schema.uri_def,
        auth_header = {type = "string"},
        timeout = {type = "integer", minimum = 1, default = 3},
        log_format = {type = "object"},
        include_req_body = {type = "boolean", default = false},
        include_req_body_expr = {
            type = "array",
            minItems = 1,
            items = {type = "array"},
        },
        include_resp_body = {type = "boolean", default = false},
        include_resp_body_expr = {
            type = "array",
            minItems = 1,
            items = {type = "array"},
        },
        concat_method = {
            type = "string",
            default = "json",
            enum = {"json", "new_line"},
        },
        ssl_verify = {type = "boolean", default = false},
    },
}

-- plugin metadata can override the log format globally
local metadata_schema = {
    type = "object",
    properties = {
        log_format = {type = "object"},
    },
}

local _M = {
    version = 0.1,
    priority = 410,
    name = plugin_name,
    schema = batch_processor_manager:wrap_schema(schema),
    metadata_schema = metadata_schema,
}
-- Validate plugin or metadata configuration, then the custom log format.
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_METADATA then
        return core.schema.check(metadata_schema, conf)
    end

    -- extra lint-style checks on https usage and TLS booleans (see core.utils)
    core.utils.check_https({"uri"}, conf, plugin_name)
    core.utils.check_tls_bool({"ssl_verify"}, conf, plugin_name)

    local ok, err = core.schema.check(schema, conf)
    if not ok then
        return nil, err
    end

    return log_util.check_log_schema(conf)
end
-- POST a batch of serialized log entries to conf.uri.
-- Returns (true) on 2xx/3xx, (false, err_msg) otherwise.
local function send_http_data(conf, log_message)
    local err_msg
    local res = true
    local url_decoded = url.parse(conf.uri)
    local host = url_decoded.host
    local port = url_decoded.port
    core.log.info("sending a batch logs to ", conf.uri)
    -- default ports by scheme when the URI omits one
    if ((not port) and url_decoded.scheme == "https") then
        port = 443
    elseif not port then
        port = 80
    end
    local httpc = http.new()
    -- conf.timeout is in seconds; resty.http expects milliseconds
    httpc:set_timeout(conf.timeout * 1000)
    local ok, err = httpc:connect(host, port)
    if not ok then
        return false, "failed to connect to host[" .. host .. "] port["
                .. tostring(port) .. "] " .. err
    end
    if url_decoded.scheme == "https" then
        ok, err = httpc:ssl_handshake(true, host, conf.ssl_verify)
        if not ok then
            return false, "failed to perform SSL with host[" .. host .. "] "
                    .. "port[" .. tostring(port) .. "] " .. err
        end
    end
    -- "new_line" concatenation sends plain text, "json" sends JSON
    local content_type
    if conf.concat_method == "json" then
        content_type = "application/json"
    else
        content_type = "text/plain"
    end
    local httpc_res, httpc_err = httpc:request({
        method = "POST",
        path = #url_decoded.path ~= 0 and url_decoded.path or "/",
        query = url_decoded.query,
        body = log_message,
        headers = {
            ["Host"] = url_decoded.host,
            ["Content-Type"] = content_type,
            ["Authorization"] = conf.auth_header
        }
    })
    if not httpc_res then
        return false, "error while sending data to [" .. host .. "] port["
                .. tostring(port) .. "] " .. httpc_err
    end
    -- some error occurred in the server
    -- NOTE(review): the response body is only consumed on >= 400 and the
    -- connection is never returned to the keepalive pool via
    -- httpc:set_keepalive() — confirm whether connection reuse is intended.
    if httpc_res.status >= 400 then
        res = false
        err_msg = "server returned status code[" .. httpc_res.status .. "] host["
                .. host .. "] port[" .. tostring(port) .. "] "
                .. "body[" .. httpc_res:read_body() .. "]"
    end
    return res, err_msg
end
-- body_filter phase: delegate response-body collection to log-util
-- (presumably gated by the include_resp_body options — see
-- apisix.utils.log-util for the exact conditions).
function _M.body_filter(conf, ctx)
    log_util.collect_body(conf, ctx)
end
-- log phase: build the log entry and hand it to the batch processor.
-- A serializer closure is created for new processors; it encodes the
-- batch per conf.concat_method and forwards it over HTTP.
function _M.log(conf, ctx)
    local entry = log_util.get_log_entry(plugin_name, conf, ctx)
    if not entry.route_id then
        entry.route_id = "no-matched"
    end
    -- fast path: an existing processor for this conf accepts the entry
    if batch_processor_manager:add_entry(conf, entry) then
        return
    end
    -- Generate a function to be executed by the batch processor
    local func = function(entries, batch_max_size)
        local data, err
        if conf.concat_method == "json" then
            if batch_max_size == 1 then
                data, err = core.json.encode(entries[1]) -- encode as single {}
            else
                data, err = core.json.encode(entries) -- encode as array [{}]
            end
        elseif conf.concat_method == "new_line" then
            if batch_max_size == 1 then
                data, err = core.json.encode(entries[1]) -- encode as single {}
            else
                local t = core.table.new(#entries, 0)
                for i, entry in ipairs(entries) do
                    t[i], err = core.json.encode(entry)
                    if err then
                        core.log.warn("failed to encode http log: ", err, ", log data: ", entry)
                        -- NOTE(review): after this break `data` is still set
                        -- from the partial table below, so the encode error
                        -- is logged but a truncated batch is sent — confirm
                        -- this is intentional.
                        break
                    end
                end
                data = core.table.concat(t, "\n") -- encode as multiple string
            end
        else
            -- defensive programming check
            err = "unknown concat_method " .. (conf.concat_method or "nil")
        end
        if not data then
            return false, 'error occurred while encoding the data: ' .. err
        end
        return send_http_data(conf, data)
    end
    batch_processor_manager:add_entry_to_new_processor(conf, entry, ctx, func)
end

return _M

View File

@@ -0,0 +1,61 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local plugin = require("apisix.plugin")
local inspect = require("apisix.inspect")
local plugin_name = "inspect"
-- the inspect plugin takes no per-route configuration
local schema = {
    type = "object",
    properties = {},
}

local _M = {
    version = 0.1,
    priority = 200,
    name = plugin_name,
    schema = schema,
}
-- Validate the (empty) plugin configuration.
function _M.check_schema(conf, schema_type)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end
-- init: start the inspector, picking delay/hooks_file from the
-- plugin_attr section of the static configuration when present.
function _M.init()
    local attr = plugin.plugin_attr(plugin_name) or {}
    local delay = attr.delay
    local hooks_file = attr.hooks_file
    core.log.info("delay=", delay, ", hooks_file=", hooks_file)
    return inspect.init(delay, hooks_file)
end
-- destroy: tear down the inspector when the plugin is unloaded.
function _M.destroy()
    return inspect.destroy()
end
return _M

View File

@@ -0,0 +1,26 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local base = require("apisix.plugins.ip-restriction.init")
-- Adapt the shared ip-restriction implementation to this subsystem by
-- exposing the generic `restrict` handler as the `access` phase handler.
-- avoid unexpected data sharing
local ip_restriction = core.table.clone(base)
ip_restriction.access = base.restrict
return ip_restriction

View File

@@ -0,0 +1,122 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local ipairs = ipairs
local core = require("apisix.core")
-- per-worker cache of compiled IP matchers, keyed by the conf list table
local lrucache = core.lrucache.new({
    ttl = 300,
    count = 512,
})

-- configuration schema: exactly one of whitelist/blacklist is required
local schema = {
    type = "object",
    properties = {
        message = {
            type = "string",
            minLength = 1,
            maxLength = 1024,
            default = "Your IP address is not allowed",
        },
        response_code = {
            type = "integer",
            minimum = 403,
            maximum = 404,
            default = 403,
        },
        whitelist = {
            type = "array",
            minItems = 1,
            items = {anyOf = core.schema.ip_def},
        },
        blacklist = {
            type = "array",
            minItems = 1,
            items = {anyOf = core.schema.ip_def},
        },
    },
    oneOf = {
        {required = {"whitelist"}},
        {required = {"blacklist"}},
    },
}

local plugin_name = "ip-restriction"

local _M = {
    version = 0.1,
    priority = 3000,
    name = plugin_name,
    schema = schema,
}
-- Validate every entry of an IP/CIDR list.
-- Kept outside the schema because it is too complex to filter out all
-- invalid IPv6 addresses via regex alone.
local function validate_ip_list(list)
    if not list then
        return true
    end
    for _, cidr in ipairs(list) do
        if not core.ip.validate_cidr_or_ip(cidr) then
            return false, "invalid ip address: " .. cidr
        end
    end
    return true
end

-- Check the plugin configuration: JSON-schema first, then per-address
-- validation of both lists (the two loops were previously duplicated).
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    if not ok then
        return false, err
    end

    ok, err = validate_ip_list(conf.whitelist)
    if not ok then
        return false, err
    end

    return validate_ip_list(conf.blacklist)
end
-- Phase handler: reject the request when the client address is on the
-- blacklist, or absent from a configured whitelist. The compiled matcher
-- is cached per conf list via the module-level lrucache.
function _M.restrict(conf, ctx)
    local remote_addr = ctx.var.remote_addr
    local block = false

    if conf.blacklist then
        local matcher = lrucache(conf.blacklist, nil,
                                 core.ip.create_ip_matcher, conf.blacklist)
        if matcher then
            block = matcher:match(remote_addr)
        end
    end

    if conf.whitelist then
        local matcher = lrucache(conf.whitelist, nil,
                                 core.ip.create_ip_matcher, conf.whitelist)
        if matcher then
            block = not matcher:match(remote_addr)
        end
    end

    if not block then
        return
    end

    return conf.response_code, { message = conf.message }
end

return _M

View File

@@ -0,0 +1,279 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local consumer_mod = require("apisix.consumer")
local base64 = require("ngx.base64")
local aes = require("resty.aes")
local ngx = ngx
local sub_str = string.sub
local cipher = aes.cipher(256, "gcm")
local plugin_name = "jwe-decrypt"
-- route-level configuration: which header carries the JWE token and
-- where the decrypted plaintext is forwarded
local schema = {
    type = "object",
    required = { "header", "forward_header" },
    properties = {
        header = { type = "string", default = "Authorization" },
        forward_header = { type = "string", default = "Authorization" },
        -- when true, requests without a token are rejected
        strict = { type = "boolean", default = true },
    },
}

-- per-consumer credentials: lookup key plus the AES key material
local consumer_schema = {
    type = "object",
    required = { "key", "secret" },
    encrypt_fields = { "key", "secret" },
    properties = {
        key = { type = "string" },
        secret = { type = "string" },
        is_base64_encoded = { type = "boolean" },
    },
}

local _M = {
    version = 0.1,
    priority = 2509,
    type = 'auth',
    name = plugin_name,
    schema = schema,
    consumer_schema = consumer_schema
}
-- Validate plugin or consumer configuration. For consumers, the secret
-- must be exactly 32 bytes (A256GCM key size) unless field encryption is
-- enabled, in which case the stored value is longer and skipped.
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_CONSUMER then
        local ok, err = core.schema.check(consumer_schema, conf)
        if not ok then
            return false, err
        end
        local local_conf, err = core.config.local_conf(true)
        if not local_conf then
            return false, "failed to load the configuration file: " .. err
        end
        local encrypted = core.table.try_read_attr(local_conf, "apisix", "data_encryption",
                "enable_encrypt_fields") and (core.config.type == "etcd")
        -- if encrypted, the secret length will exceed 32 so don't check
        if not encrypted then
            -- restrict the length of secret, we use A256GCM for encryption,
            -- so the length should be 32 chars only
            if conf.is_base64_encoded then
                -- bugfix: decode_base64url returns nil for malformed input;
                -- taking its length would throw instead of failing the check
                local decoded_secret = base64.decode_base64url(conf.secret)
                if not decoded_secret then
                    return false, "the secret is not a valid base64url string"
                end
                if #decoded_secret ~= 32 then
                    return false, "the secret length after base64 decode should be 32 chars"
                end
            else
                if #conf.secret ~= 32 then
                    return false, "the secret length should be 32 chars"
                end
            end
        end
        return true
    end
    return core.schema.check(schema, conf)
end
-- Return the raw AES key for a consumer, base64url-decoding it when the
-- consumer stored it encoded.
local function get_secret(conf)
    if not conf.is_base64_encoded then
        return conf.secret
    end
    return base64.decode_base64url(conf.secret)
end
-- Split a compact JWE (five dot-separated parts) and decode its header.
-- Always returns a table; `valid` is true only when the header decodes
-- to JSON successfully.
local function load_jwe_token(jwe_token)
    local obj = { valid = false }
    obj.header, obj.enckey, obj.iv, obj.ciphertext, obj.tag =
        jwe_token:match("(.-)%.(.-)%.(.-)%.(.-)%.(.*)")
    if not obj.header then
        return obj
    end

    local decoded_header = base64.decode_base64url(obj.header)
    if not decoded_header then
        return obj
    end

    obj.header_obj = core.json.decode(decoded_header)
    if obj.header_obj then
        obj.valid = true
    end
    return obj
end
-- Decrypt the JWE ciphertext with the consumer's key, verifying the
-- authentication tag. Returns the plaintext (a single value).
local function jwe_decrypt_with_obj(o, consumer)
    local dec = base64.decode_base64url
    local aes_ctx = aes:new(get_secret(consumer.auth_conf), nil, cipher,
                            { iv = dec(o.iv) })
    local plaintext = aes_ctx:decrypt(dec(o.ciphertext), dec(o.tag))
    return plaintext
end
-- Encrypt o.plaintext with the consumer's key and assemble a compact JWE.
-- The second (encrypted-key) part is left empty, hence the "..".
local function jwe_encrypt(o, consumer)
    local enc = base64.encode_base64url
    local aes_ctx = aes:new(get_secret(consumer.auth_conf), nil, cipher,
                            { iv = o.iv })
    local res = aes_ctx:encrypt(o.plaintext)
    o.ciphertext = res[1]
    o.tag = res[2]
    return o.header .. ".." .. enc(o.iv) .. "." .. enc(o.ciphertext) .. "." .. enc(o.tag)
end
-- Look up a consumer of this plugin by its `key` attribute, or nil.
local function get_consumer(key)
    local consumer_conf = consumer_mod.plugin(plugin_name)
    local consumers = consumer_conf
        and consumer_mod.consumers_kv(plugin_name, consumer_conf, "key")
    if not consumers then
        return nil
    end
    core.log.info("consumers: ", core.json.delay_encode(consumers))
    return consumers[key]
end
-- Extract the JWE token from the configured header, stripping an
-- optional "Bearer " prefix. Returns nil when the header is absent.
local function fetch_jwe_token(conf, ctx)
    local token = core.request.header(ctx, conf.header)
    if not token then
        return nil
    end
    local prefix = sub_str(token, 1, 7)
    if prefix == 'Bearer ' or prefix == 'bearer ' then
        token = sub_str(token, 8)
    end
    return token
end
-- rewrite phase: decrypt the inbound JWE token and forward the plaintext
-- to the upstream in conf.forward_header.
function _M.rewrite(conf, ctx)
    -- fetch token and hide credentials if necessary
    local jwe_token, err = fetch_jwe_token(conf, ctx)
    if not jwe_token then
        -- bugfix: previously a missing token with strict=false fell through
        -- and crashed on `jwe_token:match()` (index on nil); non-strict mode
        -- now simply skips the plugin
        if conf.strict then
            core.log.info("failed to fetch JWE token: ", err)
            return 403, { message = "missing JWE token in request" }
        end
        return
    end

    local jwe_obj = load_jwe_token(jwe_token)
    if not jwe_obj.valid then
        return 400, { message = "JWE token invalid" }
    end

    -- the kid header names the consumer that owns the key
    if not jwe_obj.header_obj.kid then
        return 400, { message = "missing kid in JWE token" }
    end

    local consumer = get_consumer(jwe_obj.header_obj.kid)
    if not consumer then
        return 400, { message = "invalid kid in JWE token" }
    end

    local plaintext = jwe_decrypt_with_obj(jwe_obj, consumer)
    -- bugfix: jwe_decrypt_with_obj truncates decrypt() to a single return
    -- value, so the old `err ~= nil` check could never fire; test the
    -- plaintext itself instead
    if not plaintext then
        return 400, { message = "failed to decrypt JWE token" }
    end

    core.request.set_header(ctx, conf.forward_header, plaintext)
end
-- Control-API handler: encrypt ?payload for the consumer named by ?key
-- and return the compact JWE. Responds 400 on missing key, 404 on an
-- unknown consumer or encryption failure.
local function gen_token()
    local args = core.request.get_uri_args()
    if not args or not args.key then
        return core.response.exit(400)
    end
    local key = args.key
    local payload = args.payload
    if payload then
        payload = ngx.unescape_uri(payload)
    end
    local consumer = get_consumer(key)
    if not consumer then
        return core.response.exit(404)
    end
    core.log.info("consumer: ", core.json.delay_encode(consumer))
    local iv = args.iv
    if not iv then
        -- TODO: random bytes
        -- NOTE(review): a fixed IV reused with the same key breaks GCM's
        -- security guarantees; callers should always pass ?iv until this
        -- is replaced with a random value.
        iv = "123456789012"
    end
    -- "dir" + A256GCM: direct encryption with the consumer's shared key
    local obj = {
        iv = iv,
        plaintext = payload,
        header_obj = {
            kid = key,
            alg = "dir",
            enc = "A256GCM",
        },
    }
    obj.header = base64.encode_base64url(core.json.encode(obj.header_obj))
    local jwe_token = jwe_encrypt(obj, consumer)
    if jwe_token then
        return core.response.exit(200, jwe_token)
    end
    return core.response.exit(404)
end
-- Expose a helper endpoint that encrypts a payload for a given consumer
-- key (useful for generating test tokens).
function _M.api()
    local encrypt_route = {
        methods = { "GET" },
        uri = "/apisix/plugin/jwe/encrypt",
        handler = gen_token,
    }
    return { encrypt_route }
end

return _M

View File

@@ -0,0 +1,331 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local jwt = require("resty.jwt")
local consumer_mod = require("apisix.consumer")
local resty_random = require("resty.random")
local new_tab = require ("table.new")
local auth_utils = require("apisix.utils.auth")
local ngx_encode_base64 = ngx.encode_base64
local ngx_decode_base64 = ngx.decode_base64
local ngx = ngx
local sub_str = string.sub
local table_insert = table.insert
local table_concat = table.concat
local ngx_re_gmatch = ngx.re.gmatch
local plugin_name = "jwt-auth"
local schema_def = require("apisix.schema_def")
-- route-level configuration: where to look for the JWT and how to treat it
local schema = {
    type = "object",
    properties = {
        header = { type = "string", default = "authorization" },
        query = { type = "string", default = "jwt" },
        cookie = { type = "string", default = "jwt" },
        hide_credentials = { type = "boolean", default = false },
        -- JWT claim that carries the consumer lookup key
        key_claim_name = { type = "string", default = "key", minLength = 1 },
        store_in_ctx = { type = "boolean", default = false },
        anonymous_consumer = schema_def.anonymous_consumer_schema,
    },
}

-- per-consumer credentials
local consumer_schema = {
    type = "object",
    -- can't use additionalProperties with dependencies
    properties = {
        key = { type = "string", minLength = 1 },
        secret = { type = "string", minLength = 1 },
        algorithm = {
            type = "string",
            enum = { "HS256", "HS512", "RS256", "ES256" },
            default = "HS256",
        },
        exp = { type = "integer", minimum = 1, default = 86400 },
        base64_secret = { type = "boolean", default = false },
        lifetime_grace_period = { type = "integer", minimum = 0, default = 0 },
    },
    -- asymmetric algorithms additionally require a public key
    dependencies = {
        algorithm = {
            oneOf = {
                {
                    properties = {
                        algorithm = {
                            enum = { "HS256", "HS512" },
                            default = "HS256",
                        },
                    },
                },
                {
                    properties = {
                        public_key = { type = "string" },
                        algorithm = {
                            enum = { "RS256", "ES256" },
                        },
                    },
                    required = { "public_key" },
                },
            },
        },
    },
    encrypt_fields = { "secret" },
    required = { "key" },
}

local _M = {
    version = 0.1,
    priority = 2510,
    type = 'auth',
    name = plugin_name,
    schema = schema,
    consumer_schema = consumer_schema
}
-- Validate plugin or consumer configuration. HMAC consumers without a
-- secret get a random base64-encoded one generated in place.
function _M.check_schema(conf, schema_type)
    core.log.info("input conf: ", core.json.delay_encode(conf))

    if schema_type ~= core.schema.TYPE_CONSUMER then
        return core.schema.check(schema, conf)
    end

    local ok, err = core.schema.check(consumer_schema, conf)
    if not ok then
        return false, err
    end

    if conf.algorithm ~= "RS256" and conf.algorithm ~= "ES256" and not conf.secret then
        conf.secret = ngx_encode_base64(resty_random.bytes(32, true))
    elseif conf.base64_secret then
        if ngx_decode_base64(conf.secret) == nil then
            return false, "base64_secret required but the secret is not in base64 format"
        end
    end

    return true
end
-- Rebuild a Cookie header value with the cookie named `key` removed.
-- On any regex failure the original header is returned untouched.
local function remove_specified_cookie(src, key)
    local pattern = "([a-zA-Z0-9-_]*)" .. "=" .. "([a-zA-Z0-9-._]*)"
    local kept = new_tab(1, 0)

    local it, err = ngx_re_gmatch(src, pattern, "jo")
    if not it then
        core.log.error("match origins failed: ", err)
        return src
    end

    repeat
        local m, iter_err = it()
        if iter_err then
            core.log.error("iterate origins failed: ", iter_err)
            return src
        end
        -- keep every cookie pair except the one being scrubbed
        if m and m[1] ~= key then
            table_insert(kept, m[0])
        end
    until not m

    return table_concat(kept, "; ")
end
-- Locate the JWT in header, query string, or cookie (in that order),
-- optionally scrubbing the credential from the request when
-- hide_credentials is enabled.
local function fetch_jwt_token(conf, ctx)
    -- 1. header, with optional "Bearer " prefix
    local token = core.request.header(ctx, conf.header)
    if token then
        if conf.hide_credentials then
            core.request.set_header(ctx, conf.header, nil)
        end
        local prefix = sub_str(token, 1, 7)
        if prefix == 'Bearer ' or prefix == 'bearer ' then
            token = sub_str(token, 8)
        end
        return token
    end

    -- 2. query string argument
    local uri_args = core.request.get_uri_args(ctx) or {}
    token = uri_args[conf.query]
    if token then
        if conf.hide_credentials then
            uri_args[conf.query] = nil
            core.request.set_uri_args(ctx, uri_args)
        end
        return token
    end

    -- 3. cookie
    local cookie_val = ctx.var["cookie_" .. conf.cookie]
    if not cookie_val then
        return nil, "JWT not found in cookie"
    end
    if conf.hide_credentials then
        local src = core.request.header(ctx, "Cookie")
        core.request.set_header(ctx, "Cookie",
                                remove_specified_cookie(src, conf.cookie))
    end
    return cookie_val
end
-- Return the consumer's HMAC secret, decoding it when stored as base64.
local function get_secret(conf)
    if conf.base64_secret then
        return ngx_decode_base64(conf.secret)
    end
    return conf.secret
end
-- Pick the verification material for the consumer's algorithm:
-- the shared secret for HS* (the default), the public key for RS256/ES256.
local function get_auth_secret(auth_conf)
    local alg = auth_conf.algorithm
    if not alg or alg == "HS256" or alg == "HS512" then
        return get_secret(auth_conf)
    end
    if alg == "RS256" or alg == "ES256" then
        return auth_conf.public_key
    end
end
-- Authenticate the request's JWT and resolve its consumer.
-- Returns (consumer, consumer_conf) on success or (nil, nil, err) on
-- failure; detailed error strings are only propagated when running under
-- the multi-auth plugin.
local function find_consumer(conf, ctx)
    -- fetch token and hide credentials if necessary
    local jwt_token, err = fetch_jwt_token(conf, ctx)
    if not jwt_token then
        core.log.info("failed to fetch JWT token: ", err)
        return nil, nil, "Missing JWT token in request"
    end

    local jwt_obj = jwt:load_jwt(jwt_token)
    core.log.info("jwt object: ", core.json.delay_encode(jwt_obj))
    if not jwt_obj.valid then
        err = "JWT token invalid: " .. jwt_obj.reason
        if auth_utils.is_running_under_multi_auth(ctx) then
            return nil, nil, err
        end
        core.log.warn(err)
        return nil, nil, "JWT token invalid"
    end

    -- the claim carrying the consumer lookup key is configurable
    local key_claim_name = conf.key_claim_name
    local user_key = jwt_obj.payload and jwt_obj.payload[key_claim_name]
    if not user_key then
        return nil, nil, "missing user key in JWT token"
    end

    local consumer, consumer_conf, err = consumer_mod.find_consumer(plugin_name, "key", user_key)
    if not consumer then
        core.log.warn("failed to find consumer: ", err or "invalid user key")
        return nil, nil, "Invalid user key in JWT token"
    end
    core.log.info("consumer: ", core.json.delay_encode(consumer))

    local auth_secret, err = get_auth_secret(consumer.auth_conf)
    if not auth_secret then
        -- bugfix: get_auth_secret may return no error message at all
        -- (e.g. an unrecognized algorithm), so guard the concatenation
        -- against nil
        err = "failed to retrieve secrets, err: " .. (err or "unknown")
        if auth_utils.is_running_under_multi_auth(ctx) then
            return nil, nil, err
        end
        core.log.error(err)
        return nil, nil, "failed to verify jwt"
    end

    -- verify signature and registered claims, honoring the consumer's
    -- configured clock-skew grace period
    local claim_specs = jwt:get_default_validation_options(jwt_obj)
    claim_specs.lifetime_grace_period = consumer.auth_conf.lifetime_grace_period

    jwt_obj = jwt:verify_jwt_obj(auth_secret, jwt_obj, claim_specs)
    core.log.info("jwt object: ", core.json.delay_encode(jwt_obj))
    if not jwt_obj.verified then
        err = "failed to verify jwt: " .. jwt_obj.reason
        if auth_utils.is_running_under_multi_auth(ctx) then
            return nil, nil, err
        end
        core.log.warn(err)
        return nil, nil, "failed to verify jwt"
    end

    -- optionally expose the verified payload to later phases/plugins
    if conf.store_in_ctx then
        ctx.jwt_auth_payload = jwt_obj.payload
    end

    return consumer, consumer_conf
end
-- rewrite phase: authenticate via JWT, falling back to the anonymous
-- consumer when one is configured.
function _M.rewrite(conf, ctx)
    local consumer, consumer_conf, err = find_consumer(conf, ctx)

    if not consumer then
        if not conf.anonymous_consumer then
            return 401, { message = err }
        end

        consumer, consumer_conf, err =
            consumer_mod.get_anonymous_consumer(conf.anonymous_consumer)
        if not consumer then
            err = "jwt-auth failed to authenticate the request, code: 401. error: " .. err
            core.log.error(err)
            return 401, { message = "Invalid user authorization"}
        end
    end

    core.log.info("consumer: ", core.json.delay_encode(consumer))
    consumer_mod.attach_consumer(ctx, consumer, consumer_conf)
    core.log.info("hit jwt-auth rewrite")
end

return _M

View File

@@ -0,0 +1,327 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local expr = require("resty.expr.v1")
local core = require("apisix.core")
local log_util = require("apisix.utils.log-util")
local producer = require ("resty.kafka.producer")
local bp_manager_mod = require("apisix.utils.batch-processor-manager")
local plugin = require("apisix.plugin")
local math = math
local pairs = pairs
local type = type
local req_read_body = ngx.req.read_body
local plugin_name = "kafka-logger"
local batch_processor_manager = bp_manager_mod.new("kafka logger")
local lrucache = core.lrucache.new({
type = "plugin",
})
-- JSON schema for the kafka-logger plugin configuration.
local schema = {
    type = "object",
    properties = {
        meta_format = {
            type = "string",
            default = "default",
            enum = {"default", "origin"},
        },
        log_format = {type = "object"},
        -- deprecated, use "brokers" instead
        broker_list = {
            type = "object",
            minProperties = 1,
            patternProperties = {
                [".*"] = {
                    description = "the port of kafka broker",
                    type = "integer",
                    minimum = 1,
                    maximum = 65535,
                },
            },
        },
        brokers = {
            type = "array",
            minItems = 1,
            items = {
                type = "object",
                properties = {
                    host = {
                        type = "string",
                        description = "the host of kafka broker",
                    },
                    port = {
                        type = "integer",
                        minimum = 1,
                        maximum = 65535,
                        description = "the port of kafka broker",
                    },
                    sasl_config = {
                        type = "object",
                        description = "sasl config",
                        properties = {
                            mechanism = {
                                type = "string",
                                default = "PLAIN",
                                enum = {"PLAIN"},
                            },
                            user = { type = "string", description = "user" },
                            password = { type = "string", description = "password" },
                        },
                        required = {"user", "password"},
                    },
                },
                required = {"host", "port"},
            },
            uniqueItems = true,
        },
        kafka_topic = {type = "string"},
        producer_type = {
            type = "string",
            default = "async",
            enum = {"async", "sync"},
        },
        required_acks = {
            type = "integer",
            default = 1,
            enum = { 1, -1 },
        },
        key = {type = "string"},
        timeout = {type = "integer", minimum = 1, default = 3},
        include_req_body = {type = "boolean", default = false},
        include_req_body_expr = {
            type = "array",
            minItems = 1,
            items = {
                type = "array"
            }
        },
        include_resp_body = {type = "boolean", default = false},
        include_resp_body_expr = {
            type = "array",
            minItems = 1,
            items = {
                type = "array"
            }
        },
        max_req_body_bytes = {type = "integer", minimum = 1, default = 524288},
        max_resp_body_bytes = {type = "integer", minimum = 1, default = 524288},
        -- in lua-resty-kafka, cluster_name is defined as number
        -- see https://github.com/doujiang24/lua-resty-kafka#new-1
        cluster_name = {type = "integer", minimum = 1, default = 1},
        -- config for lua-resty-kafka, default value is same as lua-resty-kafka
        producer_batch_num = {type = "integer", minimum = 1, default = 200},
        producer_batch_size = {type = "integer", minimum = 0, default = 1048576},
        producer_max_buffering = {type = "integer", minimum = 1, default = 50000},
        producer_time_linger = {type = "integer", minimum = 1, default = 1},
        meta_refresh_interval = {type = "integer", minimum = 1, default = 30},
    },
    -- either the deprecated broker_list or the new brokers field must be
    -- present, always together with kafka_topic
    oneOf = {
        { required = {"broker_list", "kafka_topic"},},
        { required = {"brokers", "kafka_topic"},},
    }
}
-- JSON schema for the plugin metadata (shared across all plugin instances).
local metadata_schema = {
    type = "object",
    properties = {
        log_format = {
            type = "object"
        },
        max_pending_entries = {
            type = "integer",
            description = "maximum number of pending entries in the batch processor",
            minimum = 1,
        },
    },
}


local _M = {
    version = 0.1,
    priority = 403,
    name = plugin_name,
    schema = batch_processor_manager:wrap_schema(schema),
    metadata_schema = metadata_schema,
}
--- Validate the plugin configuration, or the plugin metadata when
-- schema_type indicates so. Returns true, or nil plus an error message.
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_METADATA then
        return core.schema.check(metadata_schema, conf)
    end

    local ok, err = core.schema.check(schema, conf)
    if ok then
        -- a custom log_format (if any) needs an extra structural check
        return log_util.check_log_schema(conf)
    end
    return nil, err
end
-- Best-effort lookup of the partition a message was assigned to; only used
-- for debug logging (executed lazily through core.log.delay_exec).
local function get_partition_id(prod, topic, log_message)
    if prod.async then
        local ringbuffer = prod.ringbuffer
        -- the async ring buffer stores triples (topic, key, message),
        -- hence the step of 3 and the i+2 offset for the message
        for i = 1, ringbuffer.size, 3 do
            if ringbuffer.queue[i] == topic and
                ringbuffer.queue[i+2] == log_message then
                return math.floor(i / 3)
            end
        end
        core.log.info("current topic in ringbuffer has no message")
        return nil
    end

    -- sync mode
    local sendbuffer = prod.sendbuffer
    if not sendbuffer.topics[topic] then
        core.log.info("current topic in sendbuffer has no message")
        return nil
    end
    for i, message in pairs(sendbuffer.topics[topic]) do
        if log_message == message.queue[2] then
            return i
        end
    end
end
-- Factory used by the lrucache below; only invoked when no cached
-- producer exists for the current plugin context.
local function create_producer(broker_list, broker_config, cluster_name)
    core.log.info("create new kafka producer instance")
    return producer:new(broker_list, broker_config, cluster_name)
end
-- Send one (possibly batched/encoded) log message to the configured topic.
-- Returns true on success, or false plus an error description.
local function send_kafka_data(conf, log_message, prod)
    local ok, err = prod:send(conf.kafka_topic, conf.key, log_message)
    core.log.info("partition_id: ",
                  core.log.delay_exec(get_partition_id,
                                      prod, conf.kafka_topic, log_message))

    if not ok then
        -- conf.broker_list is the deprecated field and is nil when the user
        -- configured "brokers"; fall back so the error message shows the
        -- actually-configured endpoints instead of "null"
        return false, "failed to send data to Kafka topic: " .. err ..
                      ", brokers: " .. core.json.encode(conf.broker_list or conf.brokers)
    end

    return true
end
-- Read the request body in the access phase when include_req_body is on,
-- optionally gated by the include_req_body_expr expression.
function _M.access(conf, ctx)
    if conf.include_req_body then
        local should_read_body = true
        if conf.include_req_body_expr then
            if not conf.request_expr then
                -- compile the expression once and cache it on the conf table
                local request_expr, err = expr.new(conf.include_req_body_expr)
                if not request_expr then
                    core.log.error('generate request expr err ', err)
                    return
                end
                conf.request_expr = request_expr
            end

            local result = conf.request_expr:eval(ctx.var)
            if not result then
                should_read_body = false
            end
        end
        if should_read_body then
            req_read_body()
        end
    end
end
-- Collect response body chunks (honors include_resp_body and its expr).
function _M.body_filter(conf, ctx)
    log_util.collect_body(conf, ctx)
end
-- Build the log entry and hand it to the batch processor; on the first call
-- for a context a kafka producer is created from the conf and cached.
function _M.log(conf, ctx)
    local metadata = plugin.plugin_metadata(plugin_name)
    local max_pending_entries = metadata and metadata.value and
                                metadata.value.max_pending_entries or nil
    local entry
    if conf.meta_format == "origin" then
        entry = log_util.get_req_original(ctx, conf)
    else
        entry = log_util.get_log_entry(plugin_name, conf, ctx)
    end

    if batch_processor_manager:add_entry(conf, entry, max_pending_entries) then
        return
    end

    -- reuse producer via lrucache to avoid unbalanced partitions of messages in kafka
    local broker_list = core.table.clone(conf.brokers or {})
    local broker_config = {}

    -- merge the deprecated broker_list (host -> port map) into broker_list
    if conf.broker_list then
        for host, port in pairs(conf.broker_list) do
            local broker = {
                host = host,
                port = port
            }
            core.table.insert(broker_list, broker)
        end
    end

    broker_config["request_timeout"] = conf.timeout * 1000
    broker_config["producer_type"] = conf.producer_type
    broker_config["required_acks"] = conf.required_acks
    broker_config["batch_num"] = conf.producer_batch_num
    broker_config["batch_size"] = conf.producer_batch_size
    broker_config["max_buffering"] = conf.producer_max_buffering
    broker_config["flush_time"] = conf.producer_time_linger * 1000
    broker_config["refresh_interval"] = conf.meta_refresh_interval * 1000

    local prod, err = core.lrucache.plugin_ctx(lrucache, ctx, nil, create_producer,
                                               broker_list, broker_config, conf.cluster_name)
    -- check the result BEFORE dereferencing prod: the previous code logged
    -- prod.client.broker_list[1].port first, which crashed with a nil-index
    -- error whenever producer creation failed
    if not prod then
        return nil, "failed to identify the broker specified: " .. (err or "unknown error")
    end
    core.log.info("kafka cluster name ", conf.cluster_name, ", broker_list[1] port ",
                  prod.client.broker_list[1].port)

    -- Generate a function to be executed by the batch processor
    local func = function(entries, batch_max_size)
        local data, err
        if batch_max_size == 1 then
            data = entries[1]
            if type(data) ~= "string" then
                data, err = core.json.encode(data) -- encode as single {}
            end
        else
            data, err = core.json.encode(entries) -- encode as array [{}]
        end

        if not data then
            return false, 'error occurred while encoding the data: ' .. err
        end

        core.log.info("send data to kafka: ", data)
        return send_kafka_data(conf, data, prod)
    end

    batch_processor_manager:add_entry_to_new_processor(conf, entry, ctx, func, max_pending_entries)
end
return _M

View File

@@ -0,0 +1,62 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
-- Configuration schema: optional SASL/PLAIN credentials used by the kafka
-- consumer (pubsub) connection; the password is stored encrypted.
local schema = {
    type = "object",
    properties = {
        sasl = {
            type = "object",
            properties = {
                username = {
                    type = "string",
                },
                password = {
                    type = "string",
                },
            },
            required = {"username", "password"},
        },
    },
    encrypt_fields = {"sasl.password"},
}


local _M = {
    version = 0.1,
    priority = 508,
    name = "kafka-proxy",
    schema = schema,
}
-- Validate the plugin configuration against the schema above.
function _M.check_schema(conf)
    return core.schema.check(schema, conf)
end
-- Expose the configured SASL credentials to the kafka consumer through
-- the request context; a no-op when no sasl block is configured.
function _M.access(conf, ctx)
    local sasl = conf.sasl
    if not sasl then
        return
    end

    ctx.kafka_consumer_enable_sasl = true
    ctx.kafka_consumer_sasl_username = sasl.username
    ctx.kafka_consumer_sasl_password = sasl.password
end
return _M

View File

@@ -0,0 +1,124 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local consumer_mod = require("apisix.consumer")
local plugin_name = "key-auth"
local schema_def = require("apisix.schema_def")
-- Route/service-level configuration: where to look for the API key and
-- whether to strip it before proxying.
local schema = {
    type = "object",
    properties = {
        header = {
            type = "string",
            default = "apikey",
        },
        query = {
            type = "string",
            default = "apikey",
        },
        hide_credentials = {
            type = "boolean",
            default = false,
        },
        anonymous_consumer = schema_def.anonymous_consumer_schema,
    },
}

-- Consumer-level configuration: each consumer declares its API key
-- (stored encrypted).
local consumer_schema = {
    type = "object",
    properties = {
        key = { type = "string" },
    },
    encrypt_fields = {"key"},
    required = {"key"},
}


local _M = {
    version = 0.1,
    priority = 2500,
    type = 'auth',
    name = plugin_name,
    schema = schema,
    consumer_schema = consumer_schema,
}
-- Validate either a consumer config (against consumer_schema) or a
-- route/service config (against the plugin schema).
function _M.check_schema(conf, schema_type)
    local target = schema
    if schema_type == core.schema.TYPE_CONSUMER then
        target = consumer_schema
    end
    return core.schema.check(target, conf)
end
-- Locate the consumer matching the API key carried in the request.
-- The key is taken from the configured header first, then from the query
-- string. Returns (consumer, consumer_conf) or (nil, nil, error-message).
local function find_consumer(ctx, conf)
    local from_header = true
    local key = core.request.header(ctx, conf.header)

    if not key then
        local uri_args = core.request.get_uri_args(ctx) or {}
        key = uri_args[conf.query]
        from_header = false
    end

    if not key then
        return nil, nil, "Missing API key in request"
    end

    local consumer, consumer_conf, err = consumer_mod.find_consumer(plugin_name, "key", key)
    if not consumer then
        core.log.warn("failed to find consumer: ", err or "invalid api key")
        return nil, nil, "Invalid API key in request"
    end
    core.log.info("consumer: ", core.json.delay_encode(consumer))

    -- strip the credential so it is not forwarded upstream
    if conf.hide_credentials then
        if from_header then
            core.request.set_header(ctx, conf.header, nil)
        else
            local args = core.request.get_uri_args(ctx)
            args[conf.query] = nil
            core.request.set_uri_args(ctx, args)
        end
    end

    return consumer, consumer_conf
end
-- Authenticate the request in the rewrite phase; falls back to the
-- configured anonymous consumer when lookup fails, otherwise returns 401.
function _M.rewrite(conf, ctx)
    local consumer, consumer_conf, err = find_consumer(ctx, conf)
    if not consumer then
        if not conf.anonymous_consumer then
            return 401, { message = err}
        end
        consumer, consumer_conf, err = consumer_mod.get_anonymous_consumer(conf.anonymous_consumer)
        if not consumer then
            -- log the detailed failure, return a generic message to the client
            err = "key-auth failed to authenticate the request, code: 401. error: " .. err
            core.log.error(err)
            return 401, { message = "Invalid user authorization"}
        end
    end

    core.log.info("consumer: ", core.json.delay_encode(consumer))
    consumer_mod.attach_consumer(ctx, consumer, consumer_conf)
    core.log.info("hit key-auth rewrite")
end
return _M

View File

@@ -0,0 +1,229 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local type = type
local pairs = pairs
local math_random = math.random
local ngx = ngx
local http = require("resty.http")
local bp_manager_mod = require("apisix.utils.batch-processor-manager")
local core = require("apisix.core")
local str_format = core.string.format
local plugin_name = "lago"
local batch_processor_manager = bp_manager_mod.new("lago logger")
-- Plugin configuration schema: the "core" fields describe the Lago API and
-- the event template, the rest tune the outgoing HTTP connection.
local schema = {
    type = "object",
    properties = {
        -- core configurations
        endpoint_addrs = {
            type = "array",
            minItems = 1,
            items = core.schema.uri_def,
            description = "Lago API address, like http://127.0.0.1:3000, "
                .. "it supports both self-hosted and cloud. If multiple endpoints are"
                .. " configured, the log will be pushed to a randomly determined"
                .. " endpoint from the list.",
        },
        endpoint_uri = {
            type = "string",
            minLength = 1,
            default = "/api/v1/events/batch",
            description = "Lago API endpoint, it needs to be set to the batch send endpoint.",
        },
        token = {
            type = "string",
            description = "Lago API key, create one for your organization on dashboard."
        },
        event_transaction_id = {
            type = "string",
            description = "Event's transaction ID, it is used to identify and de-duplicate"
                .. " the event, it supports string templates containing APISIX and"
                .. " NGINX variables, like \"req_${request_id}\", which allows you"
                .. " to use values returned by upstream services or request-id"
                .. " plugin integration",
        },
        event_subscription_id = {
            type = "string",
            description = "Event's subscription ID, which is automatically generated or"
                .. " specified by you when you assign the plan to the customer on"
                .. " Lago, used to associate API consumption to a customer subscription,"
                .. " it supports string templates containing APISIX and NGINX variables,"
                .. " like \"cus_${consumer_name}\", which allows you to use values"
                .. " returned by upstream services or APISIX consumer",
        },
        event_code = {
            type = "string",
            description = "Lago billable metric's code for associating an event to a specified"
                .. "billable item",
        },
        event_properties = {
            type = "object",
            patternProperties = {
                [".*"] = {
                    type = "string",
                    minLength = 1,
                },
            },
            description = "Event's properties, used to attach information to an event, this"
                .. " allows you to send certain information on a event to Lago, such"
                .. " as sending HTTP status to take a failed request off the bill, or"
                .. " sending the AI token consumption in the response body for accurate"
                .. " billing, its keys are fixed strings and its values can be string"
                .. " templates containing APISIX and NGINX variables, like \"${status}\""
        },
        -- connection layer configurations
        ssl_verify = {type = "boolean", default = true},
        timeout = {
            type = "integer",
            minimum = 1,
            maximum = 60000,
            default = 3000,
            description = "timeout in milliseconds",
        },
        keepalive = {type = "boolean", default = true},
        keepalive_timeout = {
            type = "integer",
            minimum = 1000,
            default = 60000,
            description = "keepalive timeout in milliseconds",
        },
        keepalive_pool = {type = "integer", minimum = 1, default = 5},
    },
    required = {"endpoint_addrs", "token", "event_transaction_id", "event_subscription_id",
                "event_code"},
    encrypt_fields = {"token"},
}
schema = batch_processor_manager:wrap_schema(schema)

-- According to https://getlago.com/docs/api-reference/events/batch, the maximum batch size is 100,
-- so we have to override the default batch size to make it work out of the box; the plugin does
-- not set a maximum limit, so if Lago relaxes the limit, then user can modify it
-- to a larger batch size
-- This does not affect other plugins, schema is appended after deep copy
schema.properties.batch_max_size.default = 100


local _M = {
    version = 0.1,
    priority = 415,
    name = plugin_name,
    schema = schema,
}
-- Validate the plugin configuration; also warns about plain-HTTP endpoint
-- addresses and disabled TLS verification before running the schema check.
function _M.check_schema(conf, schema_type)
    core.utils.check_https({"endpoint_addrs"}, conf, plugin_name)
    core.utils.check_tls_bool({"ssl_verify"}, conf, plugin_name)
    return core.schema.check(schema, conf)
end
-- POST the batch of events to a randomly selected Lago endpoint.
-- Returns true when the API answers with a status below 300, otherwise
-- false plus an error description.
local function send_http_data(conf, data)
    local body, err = core.json.encode(data)
    if not body then
        return false, str_format("failed to encode json: %s", err)
    end

    local params = {
        headers = {
            ["Content-Type"] = "application/json",
            ["Authorization"] = "Bearer " .. conf.token,
        },
        keepalive = conf.keepalive,
        ssl_verify = conf.ssl_verify,
        method = "POST",
        body = body,
    }

    if conf.keepalive then
        params.keepalive_timeout = conf.keepalive_timeout
        params.keepalive_pool = conf.keepalive_pool
    end

    local httpc, err = http.new()
    if not httpc then
        return false, str_format("create http client error: %s", err)
    end
    httpc:set_timeout(conf.timeout)

    -- select an random endpoint and build URL
    local endpoint_url = conf.endpoint_addrs[math_random(#conf.endpoint_addrs)]..conf.endpoint_uri
    local res, err = httpc:request_uri(endpoint_url, params)
    if not res then
        return false, err
    end

    if res.status >= 300 then
        return false, str_format("lago api returned status: %d, body: %s",
                                 res.status, res.body or "")
    end

    return true
end
-- Assemble a Lago usage event from the request context and queue it on the
-- batch processor; a new processor (with an HTTP sender) is created lazily.
function _M.log(conf, ctx)
    -- build usage event
    local event_transaction_id, err = core.utils.resolve_var(conf.event_transaction_id, ctx.var)
    if err then
        core.log.error("failed to resolve event_transaction_id, event dropped: ", err)
        return
    end

    local event_subscription_id, err = core.utils.resolve_var(conf.event_subscription_id, ctx.var)
    if err then
        core.log.error("failed to resolve event_subscription_id, event dropped: ", err)
        return
    end

    local entry = {
        transaction_id = event_transaction_id,
        external_subscription_id = event_subscription_id,
        code = conf.event_code,
        timestamp = ngx.req.start_time(),
    }

    if conf.event_properties and type(conf.event_properties) == "table" then
        -- deep-copy so template resolution never mutates the shared conf
        entry.properties = core.table.deepcopy(conf.event_properties)
        for key, value in pairs(entry.properties) do
            local new_val, err, n_resolved = core.utils.resolve_var(value, ctx.var)
            if not err and n_resolved > 0 then
                entry.properties[key] = new_val
            end
        end
    end

    if batch_processor_manager:add_entry(conf, entry) then
        return
    end

    -- generate a function to be executed by the batch processor
    local func = function(entries)
        return send_http_data(conf, {
            events = entries,
        })
    end
    batch_processor_manager:add_entry_to_new_processor(conf, entry, ctx, func)
end
return _M

View File

@@ -0,0 +1,160 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local ngx = ngx
local ngx_re = require("ngx.re")
local consumer_mod = require("apisix.consumer")
local ldap = require("resty.ldap")
-- Route/service-level configuration: where and how to reach the LDAP server.
local schema = {
    type = "object",
    title = "work with route or service object",
    properties = {
        base_dn = { type = "string" },
        ldap_uri = { type = "string" },
        use_tls = { type = "boolean", default = false },
        tls_verify = { type = "boolean", default = false },
        uid = { type = "string", default = "cn" }
    },
    required = {"base_dn","ldap_uri"},
}

-- Consumer-level configuration: the DN that identifies the consumer.
local consumer_schema = {
    type = "object",
    title = "work with consumer object",
    properties = {
        user_dn = { type = "string" },
    },
    required = {"user_dn"},
}

local plugin_name = "ldap-auth"

local _M = {
    version = 0.1,
    priority = 2540,
    type = 'auth',
    name = plugin_name,
    schema = schema,
    consumer_schema = consumer_schema
}
-- Validate either a consumer config or a route/service config; the latter
-- also triggers a warning when TLS is disabled or unverified.
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_CONSUMER then
        return core.schema.check(consumer_schema, conf)
    end

    core.utils.check_tls_bool({"use_tls", "tls_verify"}, conf, plugin_name)
    return core.schema.check(schema, conf)
end
-- Parse an HTTP Basic "Authorization" header into username/password.
-- Returns a table { username = ..., password = ... }, or nil plus an error.
local function extract_auth_header(authorization)
    local obj = { username = "", password = "" }

    local m, err = ngx.re.match(authorization, "Basic\\s(.+)", "jo")
    if err then
        -- error authorization
        return nil, err
    end
    if not m then
        return nil, "Invalid authorization header format"
    end

    local decoded = ngx.decode_base64(m[1])
    if not decoded then
        return nil, "Failed to decode authentication header: " .. m[1]
    end

    local res
    -- split on the FIRST colon only (max = 2): RFC 7617 allows the
    -- password itself to contain colons, and an unbounded split used to
    -- silently truncate such passwords at their first ":"
    res, err = ngx_re.split(decoded, ":", nil, nil, 2)
    if err then
        return nil, "Split authorization err:" .. err
    end
    if #res < 2 then
        return nil, "Split authorization err: invalid decoded data: " .. decoded
    end

    obj.username = ngx.re.gsub(res[1], "\\s+", "", "jo")
    obj.password = ngx.re.gsub(res[2], "\\s+", "", "jo")

    return obj, nil
end
-- Authenticate the request against the configured LDAP server using HTTP
-- Basic credentials, then attach the matching APISIX consumer.
function _M.rewrite(conf, ctx)
    core.log.info("plugin rewrite phase, conf: ", core.json.delay_encode(conf))

    -- 1. extract authorization from header
    local auth_header = core.request.header(ctx, "Authorization")
    if not auth_header then
        core.response.set_header("WWW-Authenticate", "Basic realm='.'")
        return 401, { message = "Missing authorization in request" }
    end

    local user, err = extract_auth_header(auth_header)
    if err or not user then
        if err then
            core.log.warn(err)
        else
            core.log.warn("nil user")
        end
        return 401, { message = "Invalid authorization in request" }
    end

    -- 2. try authenticate the user against the ldap server
    local ldap_host, ldap_port = core.utils.parse_addr(conf.ldap_uri)

    local ldapconf = {
        timeout = 10000,
        start_tls = false,
        ldap_host = ldap_host,
        ldap_port = ldap_port or 389,
        ldaps = conf.use_tls,
        tls_verify = conf.tls_verify,
        base_dn = conf.base_dn,
        attribute = conf.uid,
        keepalive = 60000,
    }
    local res, err = ldap.ldap_authenticate(user.username, user.password, ldapconf)
    if not res then
        core.log.warn("ldap-auth failed: ", err)
        return 401, { message = "Invalid user authorization" }
    end

    local user_dn = conf.uid .. "=" .. user.username .. "," .. conf.base_dn

    -- 3. Retrieve consumer for authorization plugin
    local consumer_conf = consumer_mod.plugin(plugin_name)
    if not consumer_conf then
        return 401, { message = "Missing related consumer" }
    end
    local consumers = consumer_mod.consumers_kv(plugin_name, consumer_conf, "user_dn")
    local consumer = consumers[user_dn]
    if not consumer then
        return 401, {message = "Invalid user authorization"}
    end
    consumer_mod.attach_consumer(ctx, consumer, consumer_conf)

    -- fixed: this previously logged "hit basic-auth access" (copy-paste
    -- from the basic-auth plugin), which made log grepping misleading
    core.log.info("hit ldap-auth rewrite")
end
return _M

View File

@@ -0,0 +1,94 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local limit_conn = require("apisix.plugins.limit-conn.init")
local redis_schema = require("apisix.utils.redis-schema")
local policy_to_additional_properties = redis_schema.schema
local plugin_name = "limit-conn"
-- Plugin configuration schema; the JSON-schema if/then/else branch pulls in
-- the extra redis / redis-cluster connection fields based on conf.policy.
local schema = {
    type = "object",
    properties = {
        conn = {type = "integer", exclusiveMinimum = 0},        -- limit.conn max
        burst = {type = "integer", minimum = 0},
        default_conn_delay = {type = "number", exclusiveMinimum = 0},
        only_use_default_delay = {type = "boolean", default = false},
        key = {type = "string"},
        key_type = {type = "string",
            enum = {"var", "var_combination"},
            default = "var",
        },
        policy = {
            type = "string",
            enum = {"redis", "redis-cluster", "local"},
            default = "local",
        },
        rejected_code = {
            type = "integer", minimum = 200, maximum = 599, default = 503
        },
        rejected_msg = {
            type = "string", minLength = 1
        },
        allow_degradation = {type = "boolean", default = false}
    },
    required = {"conn", "burst", "default_conn_delay", "key"},
    ["if"] = {
        properties = {
            policy = {
                enum = {"redis"},
            },
        },
    },
    ["then"] = policy_to_additional_properties.redis,
    ["else"] = {
        ["if"] = {
            properties = {
                policy = {
                    enum = {"redis-cluster"},
                },
            },
        },
        ["then"] = policy_to_additional_properties["redis-cluster"],
    }
}


local _M = {
    version = 0.1,
    priority = 1003,
    name = plugin_name,
    schema = schema,
}
-- Validate the plugin configuration.
function _M.check_schema(conf)
    return core.schema.check(schema, conf)
end

-- Increase the connection count when the request enters.
function _M.access(conf, ctx)
    return limit_conn.increase(conf, ctx)
end

-- Decrease it again when the request finishes (log phase).
function _M.log(conf, ctx)
    return limit_conn.decrease(conf, ctx)
end
return _M

View File

@@ -0,0 +1,171 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local limit_conn_new = require("resty.limit.conn").new
local core = require("apisix.core")
local is_http = ngx.config.subsystem == "http"
local sleep = core.sleep
local shdict_name = "plugin-limit-conn"
if ngx.config.subsystem == "stream" then
shdict_name = shdict_name .. "-stream"
end
local redis_single_new
local redis_cluster_new
do
local redis_src = "apisix.plugins.limit-conn.limit-conn-redis"
redis_single_new = require(redis_src).new
local cluster_src = "apisix.plugins.limit-conn.limit-conn-redis-cluster"
redis_cluster_new = require(cluster_src).new
end
local lrucache = core.lrucache.new({
type = "plugin",
})
local _M = {}
-- Build the limiter backend matching conf.policy ("local", "redis" or
-- "redis-cluster"). Returns the limiter object, or nil plus an error.
local function create_limit_obj(conf)
    if conf.policy == "local" then
        core.log.info("create new limit-conn plugin instance")
        return limit_conn_new(shdict_name, conf.conn, conf.burst,
                              conf.default_conn_delay)

    elseif conf.policy == "redis" then
        core.log.info("create new limit-conn redis plugin instance")
        return redis_single_new("plugin-limit-conn", conf, conf.conn, conf.burst,
                                conf.default_conn_delay)

    elseif conf.policy == "redis-cluster" then
        core.log.info("create new limit-conn redis-cluster plugin instance")
        return redis_cluster_new("plugin-limit-conn", conf, conf.conn, conf.burst,
                                 conf.default_conn_delay)
    else
        return nil, "policy enum not match"
    end
end
-- Increase the connection count for this request's key.
-- Returns nothing on success (possibly after sleeping for the computed
-- delay), or an HTTP status code on rejection/failure.
function _M.increase(conf, ctx)
    core.log.info("ver: ", ctx.conf_version)
    local lim, err = lrucache(conf, nil, create_limit_obj, conf)
    if not lim then
        core.log.error("failed to instantiate a resty.limit.conn object: ", err)
        if conf.allow_degradation then
            return
        end
        return 500
    end

    local conf_key = conf.key
    local key
    if conf.key_type == "var_combination" then
        local err, n_resolved
        key, err, n_resolved = core.utils.resolve_var(conf_key, ctx.var)
        if err then
            core.log.error("could not resolve vars in ", conf_key, " error: ", err)
        end

        if n_resolved == 0 then
            key = nil
        end
    else
        key = ctx.var[conf_key]
    end

    if key == nil then
        core.log.info("The value of the configured key is empty, use client IP instead")
        -- When the value of key is empty, use client IP instead
        key = ctx.var["remote_addr"]
    end

    -- scope the counter to the config version so stale counters are not
    -- shared across configuration changes
    key = key .. ctx.conf_type .. ctx.conf_version
    core.log.info("limit key: ", key)

    local delay, err = lim:incoming(key, true)
    if not delay then
        if err == "rejected" then
            if conf.rejected_msg then
                return conf.rejected_code, { error_msg = conf.rejected_msg }
            end
            return conf.rejected_code or 503
        end

        core.log.error("failed to limit conn: ", err)
        if conf.allow_degradation then
            return
        end
        return 500
    end

    if lim:is_committed() then
        if not ctx.limit_conn then
            ctx.limit_conn = core.tablepool.fetch("plugin#limit-conn", 0, 6)
        end

        -- remember (lim, key, delay, only_use_default_delay) so decrease()
        -- can call lim:leaving() in the log phase
        core.table.insert_tail(ctx.limit_conn, lim, key, delay, conf.only_use_default_delay)
    end

    if delay >= 0.001 then
        sleep(delay)
    end
end
-- Decrease the connection counts recorded by increase(); called in the log
-- phase. ctx.limit_conn stores flat records of 4 values per connection.
function _M.decrease(conf, ctx)
    local limit_conn = ctx.limit_conn
    if not limit_conn then
        return
    end

    for i = 1, #limit_conn, 4 do
        local lim = limit_conn[i]
        local key = limit_conn[i + 1]
        local delay = limit_conn[i + 2]
        local use_delay = limit_conn[i + 3]

        local latency
        if is_http then
            if not use_delay then
                if ctx.proxy_passed then
                    latency = ctx.var.upstream_response_time
                else
                    -- request was answered without an upstream; subtract
                    -- the delay this plugin injected itself
                    latency = ctx.var.request_time - delay
                end
            end
        end
        core.log.debug("request latency is ", latency) -- for test
        local conn, err = lim:leaving(key, latency)
        if not conn then
            core.log.error("failed to record the connection leaving request: ",
                           err)
            break
        end
    end

    core.tablepool.release("plugin#limit-conn", limit_conn)
    ctx.limit_conn = nil
    return
end
return _M

View File

@@ -0,0 +1,78 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local redis_cluster = require("apisix.utils.rediscluster")
local core = require("apisix.core")
local util = require("apisix.plugins.limit-conn.util")
local setmetatable = setmetatable
local ngx_timer_at = ngx.timer.at
local _M = {version = 0.1}
local mt = {
__index = _M
}
-- Create a limit-conn backend whose counters live in a Redis cluster, so
-- the limit is shared across APISIX nodes.
function _M.new(plugin_name, conf, max, burst, default_conn_delay)
    local red_cli, err = redis_cluster.new(conf, "plugin-limit-conn-redis-cluster-slot-lock")
    if not red_cli then
        return nil, err
    end
    local self = {
        conf = conf,
        plugin_name = plugin_name,
        burst = burst,
        max = max + 0,            -- just to ensure the param is good
        unit_delay = default_conn_delay,
        red_cli = red_cli,
    }
    return setmetatable(self, mt)
end

function _M.incoming(self, key, commit)
    return util.incoming(self, self.red_cli, key, commit)
end

-- Whether incoming() actually committed a slot (flag set by util.incoming).
function _M.is_committed(self)
    return self.committed
end

local function leaving_thread(premature, self, key, req_latency)
    return util.leaving(self, self.red_cli, key, req_latency)
end

-- log_by_lua can't use cosocket, so the Redis "leaving" call is deferred
-- to a zero-delay timer.
function _M.leaving(self, key, req_latency)
    local ok, err = ngx_timer_at(0, leaving_thread, self, key, req_latency)
    if not ok then
        core.log.error("failed to create timer: ", err)
        return nil, err
    end

    return ok
end
return _M

View File

@@ -0,0 +1,85 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- Single-instance-Redis-backed connection limiter for limit-conn.
local redis = require("apisix.utils.redis")
local core = require("apisix.core")
local util = require("apisix.plugins.limit-conn.util")
local ngx_timer_at = ngx.timer.at
local setmetatable = setmetatable


local _M = {version = 0.1}
local mt = {
    __index = _M
}


--- Create a limiter backed by a single Redis instance.
-- Unlike the redis-cluster variant, the Redis connection is created
-- lazily per call (see incoming/leaving), so construction cannot fail.
function _M.new(plugin_name, conf, max, burst, default_conn_delay)
    local self = {
        conf = conf,
        plugin_name = plugin_name,
        burst = burst,
        max = max + 0,    -- just to ensure the param is good
        unit_delay = default_conn_delay,
    }
    return setmetatable(self, mt)
end


-- Count one incoming request against the shared Redis counter.
function _M.incoming(self, key, commit)
    local conf = self.conf
    local red, err = redis.new(conf)
    if not red then
        return red, err
    end

    return util.incoming(self, red, key, commit)
end


-- Whether the last incoming() call actually recorded the request.
function _M.is_committed(self)
    return self.committed
end


local function leaving_thread(premature, self, key, req_latency)
    -- worker is shutting down: don't start new Redis I/O
    if premature then
        return
    end

    local red, err = redis.new(self.conf)
    if not red then
        -- a timer callback's return value is discarded, so log the
        -- failure instead of silently leaking the connection count
        core.log.error("failed to create redis client: ", err)
        return nil, err
    end

    return util.leaving(self, red, key, req_latency)
end


-- Release one connection slot.  Runs in a timer because log_by_lua
-- can't use cosockets.
function _M.leaving(self, key, req_latency)
    local ok, err = ngx_timer_at(0, leaving_thread, self, key, req_latency)
    if not ok then
        core.log.error("failed to create timer: ", err)
        return nil, err
    end

    return ok
end


return _M

View File

@@ -0,0 +1,81 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- Shared counter arithmetic for the limit-conn redis / redis-cluster
-- backends.  The counter for each key is a plain integer in Redis.
local assert = assert
local math = require "math"
local floor = math.floor
local tonumber = tonumber


local _M = {version = 0.3}


--- Record (or probe) one incoming connection.
-- @param self limiter object (reads max/burst/unit_delay, sets committed)
-- @param red connected Redis (or Redis-cluster) client
-- @param key counter key, namespaced below with "limit_conn:"
-- @param commit whether to actually record the event or only probe
-- @return delay in seconds plus current connection count, or
--         nil plus "rejected"/an error message
function _M.incoming(self, red, key, commit)
    local max = self.max
    self.committed = false
    key = "limit_conn" .. ":" .. key

    local conn, err
    if commit then
        conn, err = red:incrby(key, 1)
        if not conn then
            return nil, err
        end

        if conn > max + self.burst then
            -- over the hard limit: undo the increment before rejecting
            conn, err = red:incrby(key, -1)
            if not conn then
                return nil, err
            end

            return nil, "rejected"
        end
        self.committed = true

    else
        local conn_from_red, err = red:get(key)
        if err then
            return nil, err
        end

        -- GET returns the counter as a string, and ngx.null (not nil)
        -- when the key does not exist; tonumber() safely handles both,
        -- where the previous `(conn_from_red or 0)` would do arithmetic
        -- on ngx.null and fail
        conn = (tonumber(conn_from_red) or 0) + 1
    end

    if conn > max then
        -- make the excessive connections wait
        return self.unit_delay * floor((conn - 1) / max), conn
    end

    -- we return a 0 delay by default
    return 0, conn
end


--- Release one connection and fold the observed request latency into
-- the smoothed per-request delay (simple moving average).
function _M.leaving(self, red, key, req_latency)
    assert(key)
    key = "limit_conn" .. ":" .. key

    local conn, err = red:incrby(key, -1)
    if not conn then
        return nil, err
    end

    if req_latency then
        local unit_delay = self.unit_delay
        self.unit_delay = (req_latency + unit_delay) / 2
    end

    return conn
end


return _M

View File

@@ -0,0 +1,51 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local fetch_secrets = require("apisix.secret").fetch_secrets
local limit_count = require("apisix.plugins.limit-count.init")
local workflow = require("apisix.plugins.workflow")


local plugin_name = "limit-count"


-- Thin plugin wrapper: the actual rate-limiting logic lives in
-- apisix.plugins.limit-count.init; this module only wires it into the
-- plugin framework and the workflow plugin.
local _M = {
    version = 0.5,
    priority = 1002,
    name = plugin_name,
    schema = limit_count.schema,
    metadata_schema = limit_count.metadata_schema,
}


function _M.check_schema(conf, schema_type)
    return limit_count.check_schema(conf, schema_type)
end


function _M.access(conf, ctx)
    -- resolve secret references in the conf before using it
    conf = fetch_secrets(conf, true, conf, "")
    return limit_count.rate_limit(conf, ctx, plugin_name, 1)
end


-- Register limit-count as an action usable inside the workflow plugin.
function _M.workflow_handler()
    local function handler(conf, ctx)
        return limit_count.rate_limit(conf, ctx, plugin_name, 1)
    end

    local function schema_checker(conf)
        return limit_count.check_schema(conf)
    end

    workflow.register(plugin_name, handler, schema_checker)
end


return _M

View File

@@ -0,0 +1,332 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local apisix_plugin = require("apisix.plugin")
local tab_insert = table.insert
local ipairs = ipairs
local pairs = pairs
local redis_schema = require("apisix.utils.redis-schema")
local policy_to_additional_properties = redis_schema.schema
local get_phase = ngx.get_phase
-- constructors for the three counter backends, resolved in the do-block
local limit_redis_cluster_new
local limit_redis_new
local limit_local_new
do
    local local_src = "apisix.plugins.limit-count.limit-count-local"
    limit_local_new = require(local_src).new

    local redis_src = "apisix.plugins.limit-count.limit-count-redis"
    limit_redis_new = require(redis_src).new

    local cluster_src = "apisix.plugins.limit-count.limit-count-redis-cluster"
    limit_redis_cluster_new = require(cluster_src).new
end

-- cache of limiter objects, one per plugin conf (see gen_limit_obj)
local lrucache = core.lrucache.new({
    type = 'plugin', serial_creating = true,
})
-- remembers the first conf seen per "group" so mismatches can be detected
local group_conf_lru = core.lrucache.new({
    type = 'plugin',
})

-- header names used when no plugin metadata / per-conf value overrides them
local metadata_defaults = {
    limit_header = "X-RateLimit-Limit",
    remaining_header = "X-RateLimit-Remaining",
    reset_header = "X-RateLimit-Reset",
}

local metadata_schema = {
    type = "object",
    properties = {
        limit_header = {
            type = "string",
            default = metadata_defaults.limit_header,
        },
        remaining_header = {
            type = "string",
            default = metadata_defaults.remaining_header,
        },
        reset_header = {
            type = "string",
            default = metadata_defaults.reset_header,
        },
    },
}
local schema = {
    type = "object",
    properties = {
        -- allow "count" requests per "time_window" seconds
        count = {type = "integer", exclusiveMinimum = 0},
        time_window = {type = "integer", exclusiveMinimum = 0},
        -- routes configured with the same group share a single counter
        group = {type = "string"},
        key = {type = "string", default = "remote_addr"},
        key_type = {type = "string",
            enum = {"var", "var_combination", "constant"},
            default = "var",
        },
        rejected_code = {
            type = "integer", minimum = 200, maximum = 599, default = 503
        },
        rejected_msg = {
            type = "string", minLength = 1
        },
        policy = {
            type = "string",
            enum = {"local", "redis", "redis-cluster"},
            default = "local",
        },
        allow_degradation = {type = "boolean", default = false},
        show_limit_quota_header = {type = "boolean", default = true}
    },
    required = {"count", "time_window"},
    -- the redis / redis-cluster policies pull in their connection
    -- properties from apisix.utils.redis-schema
    ["if"] = {
        properties = {
            policy = {
                enum = {"redis"},
            },
        },
    },
    ["then"] = policy_to_additional_properties.redis,
    ["else"] = {
        ["if"] = {
            properties = {
                policy = {
                    enum = {"redis-cluster"},
                },
            },
        },
        ["then"] = policy_to_additional_properties["redis-cluster"],
    }
}
-- untouched copy of the schema, used by check_schema to enumerate the
-- legitimate property names for grouped confs
local schema_copy = core.table.deepcopy(schema)

local _M = {
    schema = schema,
    metadata_schema = metadata_schema,
}

-- identity factory for group_conf_lru: caches the first conf per group
local function group_conf(conf)
    return conf
end
--- Validate plugin (or metadata) configuration.
-- For grouped confs, additionally verify that every conf sharing the
-- same group is identical, since they share a single counter.
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_METADATA then
        return core.schema.check(metadata_schema, conf)
    end

    local ok, err = core.schema.check(schema, conf)
    if not ok then
        return false, err
    end

    if conf.group then
        -- conf._vid is set when called from the workflow plugin, which
        -- does not support shared groups
        if conf._vid then
            return false, "group is not supported"
        end

        local fields = {}
        -- When the group field is configured,
        -- we will use schema_copy to get the whitelist of properties,
        -- so that we can avoid getting injected properties.
        for k in pairs(schema_copy.properties) do
            tab_insert(fields, k)
        end
        local extra = policy_to_additional_properties[conf.policy]
        if extra then
            for k in pairs(extra.properties) do
                tab_insert(fields, k)
            end
        end

        local prev_conf = group_conf_lru(conf.group, "", group_conf, conf)

        for _, field in ipairs(fields) do
            if not core.table.deep_eq(prev_conf[field], conf[field]) then
                -- the previous messages said "limit-conn" here, which is
                -- misleading in the limit-count module's logs
                core.log.error("previous limit-count group ", prev_conf.group,
                               " conf: ", core.json.encode(prev_conf))
                core.log.error("current limit-count group ", conf.group,
                               " conf: ", core.json.encode(conf))
                return false, "group conf mismatched"
            end
        end
    end

    return true
end
-- Instantiate the counter backend selected by conf.policy
-- ("local" when unset).
local function create_limit_obj(conf, plugin_name)
    core.log.info("create new " .. plugin_name .. " plugin instance")

    local policy = conf.policy or "local"
    local name = "plugin-" .. plugin_name

    if policy == "local" then
        return limit_local_new(name, conf.count, conf.time_window)
    elseif policy == "redis" then
        return limit_redis_new(name, conf.count, conf.time_window, conf)
    elseif policy == "redis-cluster" then
        return limit_redis_cluster_new(name, conf.count, conf.time_window, conf)
    end

    return nil
end
-- Build the final counter key for this request.
local function gen_limit_key(conf, ctx, key)
    if conf.group then
        return conf.group .. ':' .. key
    end

    -- here we add a separator ':' to mark the boundary of the prefix and the key itself
    -- Here we use plugin-level conf version to prevent the counter from being reset
    -- because of changes elsewhere.
    -- A route which reuses a previous route's ID will inherit its counter.
    local conf_type = ctx.conf_type_without_consumer or ctx.conf_type
    local conf_id = ctx.conf_id_without_consumer or ctx.conf_id
    local new_key = conf_type .. conf_id .. ':' .. apisix_plugin.conf_version(conf)
                    .. ':' .. key
    if conf._vid then
        -- conf has _vid means it's from workflow plugin, add _vid to the key
        -- so that the counter is unique per action.
        return new_key .. ':' .. conf._vid
    end

    return new_key
end
-- Fetch (or lazily create) the cached limiter object for this conf.
local function gen_limit_obj(conf, ctx, plugin_name)
    if conf.group then
        -- grouped confs share one limiter, cached under the group name
        return lrucache(conf.group, "", create_limit_obj, conf, plugin_name)
    end

    local extra_key = conf.policy
    if conf._vid then
        -- a separate limiter per workflow action
        extra_key = extra_key .. '#' .. conf._vid
    end

    return core.lrucache.plugin_ctx(lrucache, ctx, extra_key, create_limit_obj,
                                    conf, plugin_name)
end
--- Apply the configured rate limit to the current request.
-- @param conf plugin conf (may carry _vid when invoked via workflow)
-- @param ctx  request context
-- @param name plugin name, used for the limiter cache and log messages
-- @param cost number of units this request consumes from the quota
-- @param dry_run when truthy (honoured by the "local" policy only),
--        probe the counter without recording the request
-- @return nothing when allowed; otherwise a status code and an optional
--         response body table
function _M.rate_limit(conf, ctx, name, cost, dry_run)
    core.log.info("ver: ", ctx.conf_version)
    core.log.info("conf: ", core.json.delay_encode(conf, true))

    local lim, err = gen_limit_obj(conf, ctx, name)
    if not lim then
        core.log.error("failed to fetch limit.count object: ", err)
        if conf.allow_degradation then
            -- degrade to "no limiting" rather than failing the request
            return
        end
        return 500
    end

    -- resolve the counter key from the request, according to key_type
    local conf_key = conf.key
    local key
    if conf.key_type == "var_combination" then
        local err, n_resolved
        key, err, n_resolved = core.utils.resolve_var(conf_key, ctx.var)
        if err then
            core.log.error("could not resolve vars in ", conf_key, " error: ", err)
        end

        if n_resolved == 0 then
            key = nil
        end
    elseif conf.key_type == "constant" then
        key = conf_key
    else
        key = ctx.var[conf_key]
    end

    if key == nil then
        core.log.info("The value of the configured key is empty, use client IP instead")
        -- When the value of key is empty, use client IP instead
        key = ctx.var["remote_addr"]
    end

    key = gen_limit_key(conf, ctx, key)
    core.log.info("limit key: ", key)

    local delay, remaining, reset
    if not conf.policy or conf.policy == "local" then
        -- the local backend additionally needs conf/cost for its
        -- reset-header bookkeeping and supports dry runs
        delay, remaining, reset = lim:incoming(key, not dry_run, conf, cost)
    else
        delay, remaining, reset = lim:incoming(key, cost)
    end

    -- header names: per-conf value wins, then plugin metadata, then defaults
    local metadata = apisix_plugin.plugin_metadata("limit-count")
    if metadata then
        metadata = metadata.value
    else
        metadata = metadata_defaults
    end
    core.log.info("limit-count plugin-metadata: ", core.json.delay_encode(metadata))

    local set_limit_headers = {
        limit_header = conf.limit_header or metadata.limit_header,
        remaining_header = conf.remaining_header or metadata.remaining_header,
        reset_header = conf.reset_header or metadata.reset_header,
    }

    -- response headers can no longer be set once we reach the log phase
    local phase = get_phase()
    local set_header = phase ~= "log"

    if not delay then
        local err = remaining
        if err == "rejected" then
            -- show count limit header when rejected
            if conf.show_limit_quota_header and set_header then
                core.response.set_header(set_limit_headers.limit_header, conf.count,
                        set_limit_headers.remaining_header, 0,
                        set_limit_headers.reset_header, reset)
            end

            if conf.rejected_msg then
                return conf.rejected_code, { error_msg = conf.rejected_msg }
            end
            return conf.rejected_code
        end

        core.log.error("failed to limit count: ", err)
        if conf.allow_degradation then
            return
        end
        return 500, {error_msg = "failed to limit count"}
    end

    if conf.show_limit_quota_header and set_header then
        core.response.set_header(set_limit_headers.limit_header, conf.count,
                set_limit_headers.remaining_header, remaining,
                set_limit_headers.reset_header, reset)
    end
end


return _M

View File

@@ -0,0 +1,79 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- Worker-local counter for limit-count, built on resty.limit.count,
-- plus a shared dict that tracks each window's end time so the
-- X-RateLimit-Reset value can be reported.
local limit_count = require("resty.limit.count")
local ngx = ngx
local ngx_time = ngx.time
local assert = assert
local setmetatable = setmetatable
local core = require("apisix.core")


local _M = {}
local mt = {
    __index = _M
}


-- Remember when the current window for `key` ends; the dict entry
-- expires together with the window itself.
local function set_endtime(self, key, time_window)
    local end_time = ngx_time() + time_window
    local success, err = self.dict:set(key, end_time, time_window)
    if not success then
        core.log.error("dict set key ", key, " error: ", err)
    end

    return time_window
end


-- Seconds until the window for `key` ends; 0 when unknown or elapsed.
local function read_reset(self, key)
    local end_time = self.dict:get(key) or 0
    local remaining_time = end_time - ngx_time()
    return remaining_time < 0 and 0 or remaining_time
end


--- Create a worker-local limiter.
-- NOTE(review): relies on the "<plugin_name>-reset-header" shared dict
-- being declared in the nginx conf — confirm against the bundled config.
function _M.new(plugin_name, limit, window)
    assert(limit > 0 and window > 0)

    local self = {
        limit_count = limit_count.new(plugin_name, limit, window),
        dict = ngx.shared[plugin_name .. "-reset-header"]
    }
    return setmetatable(self, mt)
end


--- Count one request; returns delay, remaining and seconds-to-reset.
function _M.incoming(self, key, commit, conf, cost)
    local delay, remaining = self.limit_count:incoming(key, commit, cost)

    local reset
    if remaining == conf.count - cost then
        -- first hit of a fresh window: record the window's end time
        reset = set_endtime(self, key, conf.time_window)
    else
        reset = read_reset(self, key)
    end

    return delay, remaining, reset
end


return _M

View File

@@ -0,0 +1,83 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- Redis-cluster-backed counter for limit-count.
local redis_cluster = require("apisix.utils.rediscluster")
local core = require("apisix.core")
local setmetatable = setmetatable
local tostring = tostring


local _M = {}
local mt = {
    __index = _M
}


-- Atomically initialise-or-decrement the counter and report its TTL;
-- returns {remaining, ttl}.
local script = core.string.compress_script([=[
    assert(tonumber(ARGV[3]) >= 1, "cost must be at least 1")
    local ttl = redis.call('ttl', KEYS[1])
    if ttl < 0 then
        redis.call('set', KEYS[1], ARGV[1] - ARGV[3], 'EX', ARGV[2])
        return {ARGV[1] - ARGV[3], ARGV[2]}
    end
    return {redis.call('incrby', KEYS[1], 0 - ARGV[3]), ttl}
]=])


--- Create a limiter whose counters live in a Redis cluster.
function _M.new(plugin_name, limit, window, conf)
    local red_cli, err = redis_cluster.new(conf, "plugin-limit-count-redis-cluster-slot-lock")
    if not red_cli then
        return nil, err
    end

    local self = {
        limit = limit,
        window = window,
        conf = conf,
        plugin_name = plugin_name,
        red_cli = red_cli,
    }
    return setmetatable(self, mt)
end


--- Consume `cost` (default 1) units of quota for `key`.
-- @return 0, remaining, ttl on success;
--         nil, "rejected"/error message, ttl on failure
function _M.incoming(self, key, cost)
    local full_key = self.plugin_name .. tostring(key)

    local res, err = self.red_cli:eval(script, 1, full_key,
                                       self.limit, self.window, cost or 1)
    if err then
        return nil, err, 0
    end

    local remaining, ttl = res[1], res[2]
    if remaining < 0 then
        return nil, "rejected", ttl
    end

    return 0, remaining, ttl
end


return _M

View File

@@ -0,0 +1,89 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- Single-instance-Redis-backed counter for limit-count.
local redis = require("apisix.utils.redis")
local core = require("apisix.core")
local assert = assert
local setmetatable = setmetatable
local tostring = tostring


local _M = {version = 0.3}
local mt = {
    __index = _M
}


-- Atomically initialise-or-decrement the counter and report its TTL;
-- returns {remaining, ttl}.
local script = core.string.compress_script([=[
    assert(tonumber(ARGV[3]) >= 1, "cost must be at least 1")
    local ttl = redis.call('ttl', KEYS[1])
    if ttl < 0 then
        redis.call('set', KEYS[1], ARGV[1] - ARGV[3], 'EX', ARGV[2])
        return {ARGV[1] - ARGV[3], ARGV[2]}
    end
    return {redis.call('incrby', KEYS[1], 0 - ARGV[3]), ttl}
]=])


--- Create a limiter backed by a single Redis instance.
-- The connection itself is established lazily per incoming() call.
function _M.new(plugin_name, limit, window, conf)
    assert(limit > 0 and window > 0)

    local self = {
        limit = limit,
        window = window,
        conf = conf,
        plugin_name = plugin_name,
    }
    return setmetatable(self, mt)
end


--- Consume `cost` (default 1) units of quota for `key`.
-- @return 0, remaining, ttl on success;
--         nil, "rejected"/error message, ttl on failure
function _M.incoming(self, key, cost)
    local red, err = redis.new(self.conf)
    if not red then
        return red, err, 0
    end

    local full_key = self.plugin_name .. tostring(key)

    local res
    res, err = red:eval(script, 1, full_key, self.limit, self.window, cost or 1)
    if err then
        return nil, err, 0
    end

    local remaining, ttl = res[1], res[2]

    -- put the connection back into the pool before deciding the outcome
    local ok, keepalive_err = red:set_keepalive(10000, 100)
    if not ok then
        return nil, keepalive_err, ttl
    end

    if remaining < 0 then
        return nil, "rejected", ttl
    end

    return 0, remaining, ttl
end


return _M

View File

@@ -0,0 +1,183 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local limit_req_new = require("resty.limit.req").new
local core = require("apisix.core")
local redis_schema = require("apisix.utils.redis-schema")
local policy_to_additional_properties = redis_schema.schema
local plugin_name = "limit-req"
local sleep = core.sleep

-- constructors for the two Redis-backed leaky-bucket implementations
local redis_single_new
local redis_cluster_new
do
    local redis_src = "apisix.plugins.limit-req.limit-req-redis"
    redis_single_new = require(redis_src).new

    local cluster_src = "apisix.plugins.limit-req.limit-req-redis-cluster"
    redis_cluster_new = require(cluster_src).new
end

-- cache of limiter objects, one per plugin conf
local lrucache = core.lrucache.new({
    type = "plugin",
})
local schema = {
    type = "object",
    properties = {
        -- allow "rate" requests per second, with up to "burst" extra
        -- requests delayed before rejection
        rate = {type = "number", exclusiveMinimum = 0},
        burst = {type = "number", minimum = 0},
        key = {type = "string"},
        key_type = {type = "string",
            enum = {"var", "var_combination"},
            default = "var",
        },
        policy = {
            type = "string",
            enum = {"redis", "redis-cluster", "local"},
            default = "local",
        },
        rejected_code = {
            type = "integer", minimum = 200, maximum = 599, default = 503
        },
        rejected_msg = {
            type = "string", minLength = 1
        },
        -- when true, burst requests are served immediately instead of
        -- being slept (see _M.access)
        nodelay = {
            type = "boolean", default = false
        },
        allow_degradation = {type = "boolean", default = false}
    },
    required = {"rate", "burst", "key"},
    -- the redis / redis-cluster policies pull in their connection
    -- properties from apisix.utils.redis-schema
    ["if"] = {
        properties = {
            policy = {
                enum = {"redis"},
            },
        },
    },
    ["then"] = policy_to_additional_properties.redis,
    ["else"] = {
        ["if"] = {
            properties = {
                policy = {
                    enum = {"redis-cluster"},
                },
            },
        },
        ["then"] = policy_to_additional_properties["redis-cluster"],
    }
}

local _M = {
    version = 0.1,
    priority = 1001,
    name = plugin_name,
    schema = schema,
}
--- Validate the plugin configuration against the schema above.
function _M.check_schema(conf)
    local valid, err = core.schema.check(schema, conf)
    if not valid then
        return false, err
    end

    return true
end
-- Instantiate the leaky-bucket backend selected by conf.policy.
local function create_limit_obj(conf)
    local policy = conf.policy

    if policy == "local" then
        core.log.info("create new limit-req plugin instance")
        return limit_req_new("plugin-limit-req", conf.rate, conf.burst)
    end

    if policy == "redis" then
        core.log.info("create new limit-req redis plugin instance")
        return redis_single_new("plugin-limit-req", conf, conf.rate, conf.burst)
    end

    if policy == "redis-cluster" then
        core.log.info("create new limit-req redis-cluster plugin instance")
        return redis_cluster_new("plugin-limit-req", conf, conf.rate, conf.burst)
    end

    return nil, "policy enum not match"
end
--- Rate-limit the current request with the leaky-bucket algorithm.
-- Requests within the burst allowance are delayed via sleep (unless
-- nodelay is set); requests beyond rate+burst get conf.rejected_code.
function _M.access(conf, ctx)
    local lim, err = core.lrucache.plugin_ctx(lrucache, ctx, nil,
                                              create_limit_obj, conf)
    if not lim then
        core.log.error("failed to instantiate a resty.limit.req object: ", err)
        if conf.allow_degradation then
            -- degrade to "no limiting" rather than failing the request
            return
        end
        return 500
    end

    -- resolve the counter key from the request, according to key_type
    local conf_key = conf.key
    local key
    if conf.key_type == "var_combination" then
        local err, n_resolved
        key, err, n_resolved = core.utils.resolve_var(conf_key, ctx.var)
        if err then
            core.log.error("could not resolve vars in ", conf_key, " error: ", err)
        end

        if n_resolved == 0 then
            key = nil
        end
    else
        key = ctx.var[conf_key]
    end

    if key == nil then
        core.log.info("The value of the configured key is empty, use client IP instead")
        -- When the value of key is empty, use client IP instead
        key = ctx.var["remote_addr"]
    end

    -- scope the counter to this conf type and version
    key = key .. ctx.conf_type .. ctx.conf_version
    core.log.info("limit key: ", key)

    local delay, err = lim:incoming(key, true)
    if not delay then
        if err == "rejected" then
            if conf.rejected_msg then
                return conf.rejected_code, { error_msg = conf.rejected_msg }
            end
            return conf.rejected_code
        end

        core.log.error("failed to limit req: ", err)
        if conf.allow_degradation then
            return
        end
        return 500
    end

    if delay >= 0.001 and not conf.nodelay then
        -- hold the request until its slot in the bucket is free
        sleep(delay)
    end
end


return _M

View File

@@ -0,0 +1,50 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- Redis-cluster-backed leaky-bucket limiter for the limit-req plugin.
local redis_cluster = require("apisix.utils.rediscluster")
local setmetatable = setmetatable
local util = require("apisix.plugins.limit-req.util")


local _M = {version = 0.1}
local mt = {
    __index = _M
}


--- Create a request limiter backed by a Redis cluster.
-- Rate and burst are scaled by 1000 internally to match the
-- fixed-point arithmetic in apisix.plugins.limit-req.util.
function _M.new(plugin_name, conf, rate, burst)
    local client, err = redis_cluster.new(conf, "plugin-limit-req-redis-cluster-slot-lock")
    if not client then
        return nil, err
    end

    local obj = {
        conf = conf,
        plugin_name = plugin_name,
        burst = burst * 1000,
        rate = rate * 1000,
        red_cli = client,
    }
    return setmetatable(obj, mt)
end


function _M.incoming(self, key, commit)
    return util.incoming(self, self.red_cli, key, commit)
end


return _M

View File

@@ -0,0 +1,54 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- Single-instance-Redis-backed leaky-bucket limiter for limit-req.
local redis = require("apisix.utils.redis")
-- note: the original file declared `local setmetatable = setmetatable`
-- twice; the redundant second declaration has been removed
local setmetatable = setmetatable
local util = require("apisix.plugins.limit-req.util")


local _M = {version = 0.1}
local mt = {
    __index = _M
}


--- Create a request limiter backed by a single Redis server.
-- The Redis connection itself is created lazily per incoming() call.
-- Rate and burst are scaled by 1000 internally to match the
-- fixed-point arithmetic in apisix.plugins.limit-req.util.
function _M.new(plugin_name, conf, rate, burst)
    local self = {
        conf = conf,
        plugin_name = plugin_name,
        burst = burst * 1000,
        rate = rate * 1000,
    }
    return setmetatable(self, mt)
end


function _M.incoming(self, key, commit)
    local red, err = redis.new(self.conf)
    if not red then
        return red, err
    end

    return util.incoming(self, red, key, commit)
end


return _M

View File

@@ -0,0 +1,78 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local math = require "math"
local abs = math.abs
local max = math.max
local ngx_now = ngx.now
local ngx_null = ngx.null
local tonumber = tonumber


local _M = {version = 0.1}


--- Leaky-bucket accounting shared by the limit-req redis and
-- redis-cluster backends.  Two Redis keys are kept per counter:
-- "<key>excess" (current excess, scaled by 1000) and "<key>last"
-- (timestamp of the previous event, in milliseconds).
-- the "commit" argument controls whether should we record the event in shm.
function _M.incoming(self, red, key, commit)
    local rate = self.rate
    local now = ngx_now() * 1000

    key = "limit_req" .. ":" .. key
    local excess_key = key .. "excess"
    local last_key = key .. "last"

    local excess, err = red:get(excess_key)
    if err then
        return nil, err
    end
    local last, err = red:get(last_key)
    if err then
        return nil, err
    end

    if excess ~= ngx_null and last ~= ngx_null then
        excess = tonumber(excess)
        last = tonumber(last)

        -- leak the bucket: subtract what drained since "last" and add
        -- this request (1000 = one whole request in fixed-point)
        local elapsed = now - last
        excess = max(excess - rate * abs(elapsed) / 1000 + 1000, 0)

        if excess > self.burst then
            return nil, "rejected"
        end
    else
        -- first event seen for this key
        excess = 0
    end

    if commit then
        -- NOTE(review): this read-modify-write spans multiple round
        -- trips and is not atomic, so concurrent workers may
        -- under-count; the keys are also SET without a TTL and thus
        -- never expire on their own — confirm both are intended.
        local ok
        local err
        ok, err = red:set(excess_key, excess)
        if not ok then
            return nil, err
        end

        ok, err = red:set(last_key, now)
        if not ok then
            return nil, err
        end
    end

    -- return the delay in seconds, as well as excess
    return excess / rate, excess / 1000
end


return _M

View File

@@ -0,0 +1,327 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local timers = require("apisix.timers")
local plugin = require("apisix.plugin")
local process = require("ngx.process")
local signal = require("resty.signal")
local shell = require("resty.shell")
local ipairs = ipairs
local ngx = ngx
local ngx_time = ngx.time
local ngx_update_time = ngx.update_time
local lfs = require("lfs")
local type = type
local io_open = io.open
local os_date = os.date
local os_remove = os.remove
local os_rename = os.rename
local str_sub = string.sub
local str_format = string.format
local str_byte = string.byte
local ngx_sleep = require("apisix.core.utils").sleep
local string_rfind = require("pl.stringx").rfind
-- cached result of core.config.local_conf(), refreshed in get_log_path_info
local local_conf


local plugin_name = "log-rotate"
local INTERVAL = 60 * 60    -- rotate interval (unit: second)
local MAX_KEPT = 24 * 7     -- max number of log files will be kept
local MAX_SIZE = -1         -- max size of file will be rotated
local COMPRESSION_FILE_SUFFIX = ".tar.gz"  -- compression file suffix
local rotate_time           -- used by code later in this file (not visible here)
local default_logs          -- used by code later in this file (not visible here)
local enable_compression = false
local DEFAULT_ACCESS_LOG_FILENAME = "access.log"
local DEFAULT_ERROR_LOG_FILENAME = "error.log"
local SLASH_BYTE = str_byte("/")


-- the plugin takes no per-route configuration
local schema = {
    type = "object",
    properties = {},
}


local _M = {
    version = 0.1,
    priority = 100,
    name = plugin_name,
    schema = schema,
    scope = "global",
}
--- Check whether a file can be opened for reading.
-- @tparam string path path of the file to probe
-- @treturn boolean true when the file is readable
local function file_exists(path)
    local fd = io_open(path, "r")
    if not fd then
        return false
    end
    fd:close()
    return true
end
--- Resolve directory and file name for one managed log.
-- file_type is "error.log" or "access.log"; the configured path is read
-- from nginx_config.error_log / nginx_config.http.access_log when present,
-- otherwise the nginx prefix "logs/" directory is assumed.
-- @treturn string directory, ending with "/"
-- @treturn string file name (may be "off" when the log is disabled)
local function get_log_path_info(file_type)
    local_conf = core.config.local_conf()
    local conf_path
    if file_type == "error.log" then
        conf_path = local_conf and local_conf.nginx_config and
                    local_conf.nginx_config.error_log
    else
        conf_path = local_conf and local_conf.nginx_config and
                    local_conf.nginx_config.http and
                    local_conf.nginx_config.http.access_log
    end
    local prefix = ngx.config.prefix()
    if conf_path then
        -- relative path
        if str_byte(conf_path) ~= SLASH_BYTE then
            conf_path = prefix .. conf_path
        end
        -- split at the last "/" into (dir incl. slash, bare file name)
        local n = string_rfind(conf_path, "/")
        if n ~= nil and n ~= #conf_path then
            local dir = str_sub(conf_path, 1, n)
            local name = str_sub(conf_path, n + 1)
            return dir, name
        end
    end
    return prefix .. "logs/", file_type
end
--- Comparator for table.sort: descending order, so the newest rotated
-- file (largest "<date>__" prefix) sorts first.
local function tab_sort_comp(a, b)
    return b < a
end
--- Collect already-rotated files for one log in its directory.
-- Rotated files are named "<date>__<name>" (see rename_file), possibly
-- carrying the ".tar.gz" suffix; anything else is ignored.
-- @treturn table file names sorted newest-first (lexicographic descending)
-- @treturn string the scanned directory
local function scan_log_folder(log_file_name)
    local t = {}
    local log_dir, log_name = get_log_path_info(log_file_name)
    local compression_log_type = log_name .. COMPRESSION_FILE_SUFFIX
    for file in lfs.dir(log_dir) do
        local n = string_rfind(file, "__")
        if n ~= nil then
            -- n is the start of the last "__": skip both underscores
            local log_type = file:sub(n + 2)
            if log_type == log_name or log_type == compression_log_type then
                core.table.insert(t, file)
            end
        end
    end
    core.table.sort(t, tab_sort_comp)
    return t, log_dir
end
--- Rename a live log file to its timestamped rotation target.
-- log is one entry of default_logs; date_str fills the "%s" slot of the
-- new_file template. Returns the new path on success (or when the target
-- already exists), nil when the log is off or the rename fails.
local function rename_file(log, date_str)
    if not log.new_file then
        core.log.warn(log.type, " is off")
        return
    end

    local new_file = str_format(log.new_file, date_str)
    if file_exists(new_file) then
        core.log.info("file exist: ", new_file)
        return new_file
    end

    local ok, err = os_rename(log.file, new_file)
    if ok then
        return new_file
    end

    core.log.error("move file from ", log.file, " to ", new_file,
                   " res:", ok, " msg:", err)
end
--- Compress a rotated log file to "<file>.tar.gz" and delete the original.
-- Runs tar through resty.shell with the given timeout; when tar fails the
-- uncompressed file is left in place.
local function compression_file(new_file, timeout)
    if not new_file or type(new_file) ~= "string" then
        core.log.info("compression file: ", new_file, " invalid")
        return
    end
    local n = string_rfind(new_file, "/")
    local new_filepath = str_sub(new_file, 1, n)
    local new_filename = str_sub(new_file, n + 1)
    local com_filename = new_filename .. COMPRESSION_FILE_SUFFIX
    -- cd into the directory first so the archive stores only the bare name
    local cmd = str_format("cd %s && tar -zcf %s %s", new_filepath,
                           com_filename, new_filename)
    core.log.info("log file compress command: " .. cmd)
    local ok, stdout, stderr, reason, status = shell.run(cmd, nil, timeout, nil)
    if not ok then
        core.log.error("compress log file from ", new_filename, " to ", com_filename,
                       " fail, stdout: ", stdout, " stderr: ", stderr, " reason: ", reason,
                       " status: ", status)
        return
    end
    -- os.remove returns nil plus an error message on failure
    ok, stderr = os_remove(new_file)
    if stderr then
        core.log.error("remove uncompressed log file: ", new_file,
                       " fail, err: ", stderr, " res:", ok)
    end
end
--- Populate logs_info[log_type] with the live file path and the rotation
-- name template (the "%s" is replaced with a timestamp by rename_file).
-- A configured name of "off" means the log is disabled: no paths are set.
local function init_default_logs(logs_info, log_type)
    local filepath, filename = get_log_path_info(log_type)
    logs_info[log_type] = { type = log_type }
    if filename ~= "off" then
        logs_info[log_type].file = filepath .. filename
        -- NOTE(review): filepath already ends with "/" (see get_log_path_info),
        -- so this template contains "//"; harmless on POSIX but looks
        -- unintended -- confirm
        logs_info[log_type].new_file = filepath .. "/%s__" .. filename
    end
end
--- Size of the given file in bytes, or 0 when it cannot be stat'ed.
-- (attr.size is a number, so the and/or chain is safe even for size 0,
-- which is truthy in Lua.)
local function file_size(file)
    local attr = lfs.attributes(file)
    return attr and attr.size or 0
end
--- Rotate the named logs: rename each to a timestamped file, signal the
-- nginx master to reopen its logs, optionally compress the renamed files,
-- then prune rotated files beyond max_kept per log.
-- Aborts before signalling if any rename fails.
local function rotate_file(files, now_time, max_kept, timeout)
    if core.table.isempty(files) then
        return
    end
    local new_files = core.table.new(2, 0)
    -- rename the log files
    for _, file in ipairs(files) do
        local now_date = os_date("%Y-%m-%d_%H-%M-%S", now_time)
        local new_file = rename_file(default_logs[file], now_date)
        if not new_file then
            return
        end
        core.table.insert(new_files, new_file)
    end
    -- send signal to reopen log files
    local pid = process.get_master_pid()
    core.log.warn("send USR1 signal to master process [", pid, "] for reopening log file")
    local ok, err = signal.kill(pid, signal.signum("USR1"))
    if not ok then
        core.log.error("failed to send USR1 signal for reopening log file: ", err)
    end
    if enable_compression then
        -- Waiting for nginx reopen files
        -- to avoid losing logs during compression
        ngx_sleep(0.5)
        for _, new_file in ipairs(new_files) do
            compression_file(new_file, timeout)
        end
    end
    for _, file in ipairs(files) do
        -- clean the oldest file
        -- scan_log_folder sorts newest-first, so indexes past max_kept
        -- are the oldest rotated files
        local log_list, log_dir = scan_log_folder(file)
        for i = max_kept + 1, #log_list do
            local path = log_dir .. log_list[i]
            local ok, err = os_remove(path)
            if err then
                core.log.error("remove old log file: ", path, " err: ", err, " res:", ok)
            end
        end
    end
end
--- Timer callback: decide whether to rotate now.
-- Rotates both logs on a schedule aligned to `interval`, or rotates an
-- individual log early when max_size > 0 and its file exceeds that size.
-- plugin_attr values are re-read on every tick, so config changes apply
-- without a restart.
local function rotate()
    local interval = INTERVAL
    local max_kept = MAX_KEPT
    local max_size = MAX_SIZE
    local attr = plugin.plugin_attr(plugin_name)
    local timeout = 10000 -- default timeout 10 seconds
    if attr then
        interval = attr.interval or interval
        max_kept = attr.max_kept or max_kept
        max_size = attr.max_size or max_size
        timeout = attr.timeout or timeout
        enable_compression = attr.enable_compression or enable_compression
    end
    core.log.info("rotate interval:", interval)
    core.log.info("rotate max keep:", max_kept)
    core.log.info("rotate max size:", max_size)
    core.log.info("rotate timeout:", timeout)
    if not default_logs then
        -- first init default log filepath and filename
        default_logs = {}
        init_default_logs(default_logs, DEFAULT_ACCESS_LOG_FILENAME)
        init_default_logs(default_logs, DEFAULT_ERROR_LOG_FILENAME)
    end
    ngx_update_time()
    local now_time = ngx_time()
    if not rotate_time then
        -- first init rotate time: align to the next interval boundary
        rotate_time = now_time + interval - (now_time % interval)
        core.log.info("first init rotate time is: ", rotate_time)
        return
    end
    if now_time >= rotate_time then
        local files = {DEFAULT_ACCESS_LOG_FILENAME, DEFAULT_ERROR_LOG_FILENAME}
        rotate_file(files, now_time, max_kept, timeout)
        -- reset rotate time
        rotate_time = rotate_time + interval
    elseif max_size > 0 then
        local access_log_file_size = file_size(default_logs[DEFAULT_ACCESS_LOG_FILENAME].file)
        local error_log_file_size = file_size(default_logs[DEFAULT_ERROR_LOG_FILENAME].file)
        local files = core.table.new(2, 0)
        if access_log_file_size >= max_size then
            core.table.insert(files, DEFAULT_ACCESS_LOG_FILENAME)
        end
        if error_log_file_size >= max_size then
            core.table.insert(files, DEFAULT_ERROR_LOG_FILENAME)
        end
        rotate_file(files, now_time, max_kept, timeout)
    end
end
--- Plugin init hook: register rotate() with the shared timer facility.
-- The boolean flag mirrors the one used in destroy(); its exact meaning
-- is defined by apisix.timers.
function _M.init()
    timers.register_timer("plugin#log-rotate", rotate, true)
end
--- Plugin destroy hook: remove the timer registered in init().
function _M.destroy()
    timers.unregister_timer("plugin#log-rotate", true)
end
return _M

View File

@@ -0,0 +1,351 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local plugin = require("apisix.plugin")
local bp_manager_mod = require("apisix.utils.batch-processor-manager")
local log_util = require("apisix.utils.log-util")
local path = require("pl.path")
local http = require("resty.http")
local ngx = ngx
local tostring = tostring
local pairs = pairs
local tab_concat = table.concat
local udp = ngx.socket.udp
local plugin_name = "loggly"
local batch_processor_manager = bp_manager_mod.new(plugin_name)
-- syslog severity levels (numeric values per RFC 5424); the "EMEGR" key is
-- kept as-is because it is exposed verbatim in the config enum below
local severity = {
    EMEGR = 0, -- system is unusable
    ALERT = 1, -- action must be taken immediately
    CRIT = 2, -- critical conditions
    ERR = 3, -- error conditions
    WARNING = 4, -- warning conditions
    NOTICE = 5, -- normal but significant condition
    INFO = 6, -- informational
    DEBUG = 7, -- debug-level messages
}
-- accept both upper- and lower-case spellings in the schema enums
local severity_enums = {}
do
    for k, _ in pairs(severity) do
        severity_enums[#severity_enums+1] = k
        severity_enums[#severity_enums+1] = k:lower()
    end
end
-- per-route plugin configuration
local schema = {
    type = "object",
    properties = {
        customer_token = {type = "string"},
        severity = {
            type = "string",
            default = "INFO",
            enum = severity_enums,
            description = "base severity log level",
        },
        include_req_body = {type = "boolean", default = false},
        include_req_body_expr = {
            type = "array",
            minItems = 1,
            items = {
                type = "array"
            }
        },
        include_resp_body = {type = "boolean", default = false},
        include_resp_body_expr = {
            type = "array",
            minItems = 1,
            items = {
                type = "array"
            }
        },
        tags = {
            type = "array",
            minItems = 1,
            items = {
                type = "string",
                -- we prevent of having `tag=` prefix
                pattern = "^(?!tag=)[ -~]*",
            },
            default = {"apisix"}
        },
        ssl_verify = {
            -- applicable for https protocol
            type = "boolean",
            default = true
        },
        log_format = {type = "object"},
        severity_map = {
            type = "object",
            description = "upstream response code vs syslog severity mapping",
            patternProperties = {
                ["^[1-5][0-9]{2}$"] = {
                    description = "keys are HTTP status code, values are severity",
                    type = "string",
                    enum = severity_enums
                },
            },
            additionalProperties = false
        }
    },
    required = {"customer_token"}
}
-- endpoint settings used when no plugin metadata has been configured
local defaults = {
    host = "logs-01.loggly.com",
    port = 514,
    protocol = "syslog",
    timeout = 5000
}
local metadata_schema = {
    type = "object",
    properties = {
        host = {
            type = "string",
            default = defaults.host
        },
        port = {
            type = "integer",
            default = defaults.port
        },
        protocol = {
            type = "string",
            default = defaults.protocol,
            -- in case of http and https, we use bulk endpoints
            enum = {"syslog", "http", "https"}
        },
        timeout = {
            type = "integer",
            minimum = 1,
            default= defaults.timeout
        },
        log_format = {
            type = "object",
        }
    }
}
local _M = {
    version = 0.1,
    priority = 411,
    name = plugin_name,
    schema = batch_processor_manager:wrap_schema(schema),
    metadata_schema = metadata_schema
}
--- Validate plugin (or metadata) configuration.
-- On success, pre-computes conf._severity_cache mapping HTTP status-code
-- strings to numeric syslog severities so the log phase avoids repeated
-- upper-casing and lookups.
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_METADATA then
        return core.schema.check(metadata_schema, conf)
    end
    local ok, err = core.schema.check(schema, conf)
    if not ok then
        return nil, err
    end
    if conf.severity_map then
        local cache = {}
        for k, v in pairs(conf.severity_map) do
            cache[k] = severity[v:upper()]
        end
        conf._severity_cache = cache
    end
    return log_util.check_log_schema(conf)
end
--- body_filter phase: buffer response body chunks when
-- include_resp_body(_expr) requests it (see apisix.utils.log-util).
function _M.body_filter(conf, ctx)
    log_util.collect_body(conf, ctx)
end
--- Build one log message for the configured protocol.
-- For http/https metadata the JSON-encoded entry is returned directly
-- (the bulk endpoint takes raw JSON lines); for syslog an RFC 5424 frame
-- is assembled around it. Severity comes from conf.severity, overridden
-- per response status via the pre-computed _severity_cache.
-- @treturn string|nil the message, or nil when JSON encoding fails
local function generate_log_message(conf, ctx)
    local entry = log_util.get_log_entry(plugin_name, conf, ctx)
    local json_str, err = core.json.encode(entry)
    if not json_str then
        core.log.error('error occurred while encoding the data: ', err)
        return nil
    end
    local metadata = plugin.plugin_metadata(plugin_name)
    if metadata and metadata.value.protocol ~= "syslog" then
        return json_str
    end
    -- generate rfc5424 compliant syslog event
    local timestamp = log_util.get_rfc3339_zulu_timestamp()
    local taglist = {}
    if conf.tags then
        for i = 1, #conf.tags do
            core.table.insert(taglist, "tag=\"" .. conf.tags[i] .. "\"")
        end
    end
    local message_severity = severity[conf.severity:upper()]
    if conf._severity_cache and conf._severity_cache[tostring(ngx.status)] then
        message_severity = conf._severity_cache[tostring(ngx.status)]
    end
    local message = {
        -- facility LOG_USER - random user level message
        -- PRIVAL = facility(user=1)*8 + severity
        "<".. tostring(8 + message_severity) .. ">1",-- <PRIVAL>1
        timestamp, -- timestamp
        ctx.var.host or "-", -- hostname
        "apisix", -- appname
        ctx.var.pid, -- proc-id
        "-", -- msgid
        "[" .. conf.customer_token .. "@41058 " .. tab_concat(taglist, " ") .. "]",
        json_str
    }
    return tab_concat(message, " ")
end
--- Send one syslog-formatted message over UDP to the configured endpoint.
-- A connect failure returns immediately; a send failure is recorded but
-- the socket is still closed before returning.
-- @treturn boolean success flag
-- @treturn string|nil error description on failure
local function send_data_over_udp(message, metadata)
    local err_msg
    local res = true
    local sock = udp()
    local host, port = metadata.value.host, metadata.value.port
    sock:settimeout(metadata.value.timeout)
    local ok, err = sock:setpeername(host, port)
    if not ok then
        core.log.error("failed to send log: ", err)
        return false, "failed to connect to UDP server: host[" .. host
                      .. "] port[" .. tostring(port) .. "] err: " .. err
    end
    ok, err = sock:send(message)
    if not ok then
        res = false
        core.log.error("failed to send log: ", err)
        err_msg = "failed to send data to UDP server: host[" .. host
                  .. "] port[" .. tostring(port) .. "] err:" .. err
    end
    ok, err = sock:close()
    if not ok then
        core.log.error("failed to close the UDP connection, host[",
                       host, "] port[", port, "] ", err)
    end
    return res, err_msg
end
--- POST newline-joined log entries to Loggly's HTTP bulk endpoint.
-- Builds <scheme>://<host>/bulk/<token>/tag/bulk, honoring a scheme
-- already present in the configured host.
-- @treturn boolean success
-- @treturn string|nil error description on failure
local function send_bulk_over_http(message, metadata, conf)
    local endpoint = path.join(metadata.value.host, "bulk", conf.customer_token, "tag", "bulk")
    local has_prefix = core.string.has_prefix(metadata.value.host, "http")
    if not has_prefix then
        if metadata.value.protocol == "http" then
            endpoint = "http://" .. endpoint
        else
            endpoint = "https://" .. endpoint
        end
    end
    local httpc = http.new()
    httpc:set_timeout(metadata.value.timeout)
    local res, err = httpc:request_uri(endpoint, {
        ssl_verify = conf.ssl_verify,
        method = "POST",
        body = message,
        headers = {
            ["Content-Type"] = "application/json",
            ["X-LOGGLY-TAG"] = conf.tags
        },
    })
    if not res then
        return false, "failed to write log to loggly, " .. err
    end
    if res.status ~= 200 then
        -- NOTE(review): the decoded table is never used; decoding only
        -- decides whether res.body is appended to the error -- confirm
        local body = core.json.decode(res.body)
        if not body then
            return false, "failed to send log to loggly, http status code: " .. res.status
        else
            return false, "failed to send log to loggly, http status code: " .. res.status
                          .. " response body: ".. res.body
        end
    end
    return true
end
-- forward declaration: (re)bound inside _M.log with the current conf
local handle_http_payload
--- Batch processor callback: flush one batch of pre-formatted entries.
-- syslog metadata sends each entry over UDP, returning the failing index
-- so the processor can retry the remainder; http/https delegates the whole
-- batch to the bulk endpoint via handle_http_payload.
local function handle_log(entries)
    local metadata = plugin.plugin_metadata(plugin_name)
    core.log.info("metadata: ", core.json.delay_encode(metadata))
    if not metadata then
        core.log.info("received nil metadata: using metadata defaults: ",
                      core.json.delay_encode(defaults, true))
        metadata = {}
        metadata.value = defaults
    end
    core.log.info("sending a batch logs to ", metadata.value.host)
    if metadata.value.protocol == "syslog" then
        for i = 1, #entries do
            local ok, err = send_data_over_udp(entries[i], metadata)
            if not ok then
                return false, err, i
            end
        end
    else
        return handle_http_payload(entries, metadata)
    end
    return true
end
--- log phase: format the entry and hand it to the batch processor.
-- NOTE(review): handle_http_payload is a module-level upvalue reassigned on
-- every call; with multiple routes using different conf, an async batch
-- flush uses whichever conf logged most recently -- confirm acceptable
function _M.log(conf, ctx)
    local log_data = generate_log_message(conf, ctx)
    if not log_data then
        return
    end
    handle_http_payload = function (entries, metadata)
        -- loggly bulk endpoint expects entries concatenated in newline("\n")
        local message = tab_concat(entries, "\n")
        return send_bulk_over_http(message, metadata, conf)
    end
    if batch_processor_manager:add_entry(conf, log_data) then
        return
    end
    batch_processor_manager:add_entry_to_new_processor(conf, log_data, ctx, handle_log)
end
return _M

View File

@@ -0,0 +1,251 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local bp_manager_mod = require("apisix.utils.batch-processor-manager")
local log_util = require("apisix.utils.log-util")
local core = require("apisix.core")
local http = require("resty.http")
local new_tab = require("table.new")
local pairs = pairs
local ipairs = ipairs
local tostring = tostring
local math_random = math.random
local table_insert = table.insert
local ngx = ngx
local str_format = core.string.format
local plugin_name = "loki-logger"
local batch_processor_manager = bp_manager_mod.new("loki logger")
-- per-route plugin configuration
local schema = {
    type = "object",
    properties = {
        -- core configurations
        endpoint_addrs = {
            type = "array",
            minItems = 1,
            items = core.schema.uri_def,
        },
        endpoint_uri = {
            type = "string",
            minLength = 1,
            default = "/loki/api/v1/push"
        },
        -- sent as X-Scope-OrgID for Loki multi-tenancy
        tenant_id = {type = "string", default = "fake"},
        headers = {
            type = "object",
            patternProperties = {
                [".*"] = {
                    type = "string",
                    minLength = 1,
                },
            },
        },
        -- stream labels; values may contain nginx variables, resolved at log time
        log_labels = {
            type = "object",
            patternProperties = {
                [".*"] = {
                    type = "string",
                    minLength = 1,
                },
            },
            default = {
                job = "apisix",
            },
        },
        -- connection layer configurations
        ssl_verify = {type = "boolean", default = false},
        timeout = {
            type = "integer",
            minimum = 1,
            maximum = 60000,
            default = 3000,
            description = "timeout in milliseconds",
        },
        keepalive = {type = "boolean", default = true},
        keepalive_timeout = {
            type = "integer",
            minimum = 1000,
            default = 60000,
            description = "keepalive timeout in milliseconds",
        },
        keepalive_pool = {type = "integer", minimum = 1, default = 5},
        -- logger related configurations
        log_format = {type = "object"},
        include_req_body = {type = "boolean", default = false},
        include_req_body_expr = {
            type = "array",
            minItems = 1,
            items = {
                type = "array"
            }
        },
        include_resp_body = {type = "boolean", default = false},
        include_resp_body_expr = {
            type = "array",
            minItems = 1,
            items = {
                type = "array"
            }
        },
    },
    required = {"endpoint_addrs"}
}
-- plugin metadata: only an optional global log_format override
local metadata_schema = {
    type = "object",
    properties = {
        log_format = {
            type = "object"
        }
    },
}
local _M = {
    version = 0.1,
    priority = 414,
    name = plugin_name,
    schema = batch_processor_manager:wrap_schema(schema),
    metadata_schema = metadata_schema,
}
--- Validate plugin (or metadata) configuration.
-- Runs the core.utils https / TLS-boolean checks on the endpoint settings
-- before the JSON-schema and log-format validation.
function _M.check_schema(conf, schema_type)
    if schema_type == core.schema.TYPE_METADATA then
        return core.schema.check(metadata_schema, conf)
    end
    local check = {"endpoint_addrs"}
    core.utils.check_https(check, conf, plugin_name)
    core.utils.check_tls_bool({"ssl_verify"}, conf, plugin_name)
    local ok, err = core.schema.check(schema, conf)
    if not ok then
        return nil, err
    end
    return log_util.check_log_schema(conf)
end
--- POST the encoded push payload to one randomly selected Loki endpoint.
-- @treturn boolean success
-- @treturn string|nil error message on failure (HTTP >= 300 is a failure)
local function send_http_data(conf, log)
    local headers = conf.headers or {}
    -- clone before mutating so conf.headers is never modified
    headers = core.table.clone(headers)
    headers["X-Scope-OrgID"] = conf.tenant_id
    headers["Content-Type"] = "application/json"
    local params = {
        headers = headers,
        keepalive = conf.keepalive,
        ssl_verify = conf.ssl_verify,
        method = "POST",
        body = core.json.encode(log)
    }
    if conf.keepalive then
        params.keepalive_timeout = conf.keepalive_timeout
        params.keepalive_pool = conf.keepalive_pool
    end
    local httpc, err = http.new()
    if not httpc then
        return false, str_format("create http client error: %s", err)
    end
    httpc:set_timeout(conf.timeout)
    -- select an random endpoint and build URL
    local endpoint_url = conf.endpoint_addrs[math_random(#conf.endpoint_addrs)] .. conf.endpoint_uri
    local res, err = httpc:request_uri(endpoint_url, params)
    if not res then
        return false, err
    end
    if res.status >= 300 then
        return false, str_format("loki server returned status: %d, body: %s",
                                 res.status, res.body or "")
    end
    return true
end
--- body_filter phase: buffer response body chunks when
-- include_resp_body(_expr) requests it (see apisix.utils.log-util).
function _M.body_filter(conf, ctx)
    log_util.collect_body(conf, ctx)
end
--- log phase: build the Loki entry and submit it to the batch processor.
-- Each entry carries its timestamp in nanoseconds as a string; when a
-- batch flushes, all entries are packed into one Loki stream labeled with
-- the (variable-resolved) log_labels.
function _M.log(conf, ctx)
    local entry = log_util.get_log_entry(plugin_name, conf, ctx)
    if not entry.route_id then
        entry.route_id = "no-matched"
    end

    -- insert start time as log time, multiply to nanoseconds
    -- use string concat to circumvent 64bit integers that LuaVM cannot handle
    -- that is, first process the decimal part of the millisecond value
    -- and then add 6 zeros by string concatenation
    entry.loki_log_time = tostring(ngx.req.start_time() * 1000) .. "000000"

    if batch_processor_manager:add_entry(conf, entry) then
        return
    end

    -- resolve variables on a copy of the labels: resolving them in place
    -- would mutate the shared conf.log_labels table and freeze the variable
    -- values of the first request into every later batch
    local labels = core.table.clone(conf.log_labels)
    -- parsing possible variables in label value
    for key, value in pairs(labels) do
        local new_val, err, n_resolved = core.utils.resolve_var(value, ctx.var)
        if not err and n_resolved > 0 then
            labels[key] = new_val
        end
    end

    -- generate a function to be executed by the batch processor
    local func = function(entries)
        -- build loki request data
        local data = {
            streams = {
                {
                    stream = labels,
                    values = new_tab(1, 0),
                }
            }
        }

        -- add all entries to the batch
        for _, entry in ipairs(entries) do
            local log_time = entry.loki_log_time
            entry.loki_log_time = nil -- clean logger internal field
            table_insert(data.streams[1].values, {
                log_time, core.json.encode(entry)
            })
        end

        return send_http_data(conf, data)
    end

    batch_processor_manager:add_entry_to_new_processor(conf, entry, ctx, func)
end
return _M

View File

@@ -0,0 +1,173 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local unpack = unpack
local ngx = ngx
local thread_spawn = ngx.thread.spawn
local thread_kill = ngx.thread.kill
local worker_exiting = ngx.worker.exiting
local resty_signal = require("resty.signal")
local core = require("apisix.core")
local pipe = require("ngx.pipe")
local mcp_server_wrapper = require("apisix.plugins.mcp.server_wrapper")
local schema = {
    type = "object",
    properties = {
        -- public URL prefix the MCP endpoints are served under; matched
        -- against ctx.var.uri in server_wrapper's access routine
        base_uri = {
            type = "string",
            minLength = 1,
            default = "",
        },
        -- executable that speaks MCP over stdio
        command = {
            type = "string",
            minLength = 1,
        },
        -- optional argv passed to the command
        args = {
            type = "array",
            items = {
                type = "string",
            },
            minItems = 0,
        },
    },
    required = {
        "command"
    },
}
local plugin_name = "mcp-bridge"
local _M = {
    version = 0.1,
    priority = 510,
    name = plugin_name,
    schema = schema,
}
--- Validate plugin configuration against the schema above.
function _M.check_schema(conf, schema_type)
    return core.schema.check(schema, conf)
end
--- Build the on_connect callback: spawn the configured MCP process and
-- start a light thread that pumps its stdout/stderr to the client.
-- stdout lines are forwarded verbatim as MCP responses; stderr lines are
-- wrapped in "notifications/stderr" JSON-RPC notifications.
-- Returns HTTP 500 when the process cannot be spawned.
local function on_connect(conf, ctx)
    return function(additional)
        local proc, err = pipe.spawn({conf.command, unpack(conf.args or {})})
        if not proc then
            core.log.error("failed to spawn mcp process: ", err)
            return 500
        end
        -- default write timeout; short stdout/stderr read timeouts keep
        -- the loop below responsive
        proc:set_timeouts(nil, 100, 100)
        ctx.mcp_bridge_proc = proc
        local server = additional.server
        -- ngx_pipe is a yield operation, so we no longer need
        -- to explicitly yield to other threads by ngx_sleep
        ctx.mcp_bridge_proc_event_loop = thread_spawn(function ()
            local stdout_partial, stderr_partial, need_exit
            while not worker_exiting() do
                -- read all the messages in stdout's pipe, line by line
                -- if there is an incomplete message it is buffered and
                -- spliced before the next message
                repeat
                    local line, _
                    line, _, stdout_partial = proc:stdout_read_line()
                    if line then
                        local ok, err = server.transport:send(
                            stdout_partial and stdout_partial .. line or line
                        )
                        if not ok then
                            core.log.info("session ", server.session_id,
                                " exit, failed to send response message: ", err)
                            need_exit = true
                            break
                        end
                        stdout_partial = nil -- luacheck: ignore
                    end
                until not line
                if need_exit then
                    break
                end
                -- same pumping scheme for stderr, but wrapped as a notification
                repeat
                    local line, _
                    line, _, stderr_partial = proc:stderr_read_line()
                    if line then
                        local ok, err = server.transport:send(
                            '{"jsonrpc":"2.0","method":"notifications/stderr","params":{"content":"'
                            .. (stderr_partial and stderr_partial .. line or line) .. '"}}')
                        if not ok then
                            core.log.info("session ", server.session_id,
                                " exit, failed to send response message: ", err)
                            need_exit = true
                            break
                        end
                        stderr_partial = "" -- luacheck: ignore
                    end
                until not line
                if need_exit then
                    break
                end
            end
        end)
    end
end
--- Build the on_client_message callback: forward the raw client JSON-RPC
-- payload to the child's stdin, newline-terminated (stdio framing).
-- NOTE(review): the write result is ignored; a broken pipe would only be
-- noticed when the read loop fails -- confirm acceptable
local function on_client_message(conf, ctx)
    return function(message, additional)
        core.log.info("session ", additional.server.session_id,
            " send message to mcp server: ", additional.raw)
        ctx.mcp_bridge_proc:write(additional.raw .. "\n")
    end
end
--- Build the on_disconnect callback: stop the pump thread and terminate
-- the child (graceful stdin shutdown first, SIGKILL if still running).
local function on_disconnect(conf, ctx)
    return function()
        if ctx.mcp_bridge_proc_event_loop then
            thread_kill(ctx.mcp_bridge_proc_event_loop)
            ctx.mcp_bridge_proc_event_loop = nil
        end
        local proc = ctx.mcp_bridge_proc
        if proc then
            proc:shutdown("stdin")
            proc:wait()
            -- NOTE(review): wait() is called twice; the second call appears
            -- to be used as an "already exited" probe (err == "exited") --
            -- confirm this matches ngx.pipe's wait() contract
            local _, err = proc:wait() -- check if process not exited then kill it
            if err ~= "exited" then
                proc:kill(resty_signal.signum("KILL") or 9)
            end
        end
    end
end
--- access phase: delegate to the MCP server wrapper with the process
-- lifecycle callbacks bound to this request's conf/ctx.
function _M.access(conf, ctx)
    return mcp_server_wrapper.access(conf, ctx, {
        event_handler = {
            on_connect = on_connect(conf, ctx),
            on_client_message = on_client_message(conf, ctx),
            on_disconnect = on_disconnect(conf, ctx),
        },
    })
end
return _M

View File

@@ -0,0 +1,90 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local type = type
local setmetatable = setmetatable
local ngx = ngx
local ngx_sleep = ngx.sleep
local thread_spawn = ngx.thread.spawn
local thread_kill = ngx.thread.kill
local worker_exiting = ngx.worker.exiting
local shared_dict = ngx.shared["mcp-session"] -- TODO: rename to something like mcp-broker
local core = require("apisix.core")
local broker_utils = require("apisix.plugins.mcp.broker.utils")
local _M = {}
local mt = { __index = _M }
-- per-session queue key is session_id .. ":queue"
local STORAGE_SUFFIX_QUEUE = ":queue"
--- Create a shared-dict backed message broker bound to one session.
function _M.new(opts)
    local broker = {
        session_id = opts.session_id,
        event_handler = {},
    }
    return setmetatable(broker, mt)
end
--- Register (or replace) the callback for the given event name.
function _M.on(self, event, cb)
    self.event_handler[event] = cb
end
--- Append a raw message to this session's queue in the shared dict.
-- @treturn boolean|nil true on success, nil + error message on failure
function _M.push(self, message)
    if not message then
        return nil, "message is nil"
    end
    local ok, err = shared_dict:rpush(self.session_id .. STORAGE_SUFFIX_QUEUE, message)
    if not ok then
        return nil, "failed to push message to queue: " .. err
    end
    return true
end
--- Spawn the consumer light thread: polls the session queue every 100ms
-- and dispatches each JSON-decoded item to the EVENT_MESSAGE handler,
-- passing the raw string via additional.raw.
function _M.start(self)
    self.thread = thread_spawn(function()
        while not worker_exiting() do
            local item, err = shared_dict:lpop(self.session_id .. STORAGE_SUFFIX_QUEUE)
            if err then
                core.log.info("session ", self.session_id,
                    " exit, failed to pop message from queue: ", err)
                break
            end
            if item and type(item) == "string"
                and type(self.event_handler[broker_utils.EVENT_MESSAGE]) == "function" then
                self.event_handler[broker_utils.EVENT_MESSAGE](
                    core.json.decode(item), { raw = item }
                )
            end
            ngx_sleep(0.1) -- yield to other light threads
        end
    end)
end
--- Stop the consumer light thread if one is running; safe to call twice.
function _M.close(self)
    local th = self.thread
    if not th then
        return
    end
    self.thread = nil
    thread_kill(th)
end
return _M

View File

@@ -0,0 +1,21 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- Shared constants for MCP broker implementations.
local _M = {
    -- event name emitted when a client message is popped from the queue
    EVENT_MESSAGE = "message",
}
return _M

View File

@@ -0,0 +1,116 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local require = require
local setmetatable = setmetatable
local ngx = ngx
local ngx_sleep = ngx.sleep
local thread_spwan = ngx.thread.spawn
local thread_wait = ngx.thread.wait
local thread_kill = ngx.thread.kill
local worker_exiting = ngx.worker.exiting
local core = require("apisix.core")
local broker_utils = require("apisix.plugins.mcp.broker.utils")
local _M = {}
local mt = { __index = _M }
-- event fired when the broker delivers a message from the client
_M.EVENT_CLIENT_MESSAGE = "event:client_message"
-- TODO: ping requester and handler
--- Create an MCP server instance for one client session.
-- Reuses opts.session_id when recovering an existing session (message
-- endpoint); otherwise generates a fresh UUID (SSE connect).
function _M.new(opts)
    local session_id = opts.session_id or core.id.gen_uuid_v4()
    -- TODO: configurable broker type
    local message_broker = require("apisix.plugins.mcp.broker.shared_dict").new({
        session_id = session_id,
    })
    -- TODO: configurable transport type
    local transport = require("apisix.plugins.mcp.transport.sse").new()
    local obj = setmetatable({
        opts = opts,
        session_id = session_id,
        next_ping_id = 0,
        transport = transport,
        message_broker = message_broker,
        event_handler = {},
        need_exit = false,
    }, mt)
    -- bridge broker deliveries to the server-level client-message event
    message_broker:on(broker_utils.EVENT_MESSAGE, function (message, additional)
        if obj.event_handler[_M.EVENT_CLIENT_MESSAGE] then
            obj.event_handler[_M.EVENT_CLIENT_MESSAGE](message, additional)
        end
    end)
    return obj
end
--- Register (or replace) the callback for the given event name.
function _M.on(self, event, cb)
    self.event_handler[event] = cb
end
--- Run the session: start the broker and block in a keepalive loop.
-- A light thread (note: the file-level local is spelled "thread_spwan")
-- sends a JSON-RPC ping every 30s; a failed send sets need_exit and ends
-- the loop, which lets thread_wait return to the caller.
function _M.start(self)
    self.message_broker:start()
    -- ping loop
    local ping = thread_spwan(function()
        while not worker_exiting() do
            if self.need_exit then
                break
            end
            self.next_ping_id = self.next_ping_id + 1
            local ok, err = self.transport:send(
                '{"jsonrpc": "2.0","method": "ping","id":"ping:' .. self.next_ping_id .. '"}')
            if not ok then
                core.log.info("session ", self.session_id,
                    " exit, failed to send ping message: ", err)
                self.need_exit = true
                break
            end
            ngx_sleep(30)
        end
    end)
    thread_wait(ping)
    thread_kill(ping)
end
--- Release session resources: shut the message broker down when present.
function _M.close(self)
    local broker = self.message_broker
    if broker then
        broker:close()
    end
end
--- Enqueue a raw client message for this session via the broker.
-- @treturn boolean|nil true on success, nil + error string on failure
function _M.push_message(self, message)
    local ok, err = self.message_broker:push(message)
    if ok then
        return true
    end
    return nil, "failed to push message to broker: " .. err
end
return _M

View File

@@ -0,0 +1,106 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local ngx = ngx
local ngx_exit = ngx.exit
local re_match = ngx.re.match
local core = require("apisix.core")
local mcp_server = require("apisix.plugins.mcp.server")
local _M = {}
-- endpoint names of the 2024-11-05 MCP HTTP+SSE transport revision
local V241105_ENDPOINT_SSE = "sse"
local V241105_ENDPOINT_MESSAGE = "message"
--- Serve the long-lived SSE connection for an MCP session.
-- Sends SSE headers, advertises the message endpoint, wires the client
-- message event, then runs the server loop until the client disconnects.
-- Finishes the request with ngx.exit(0) so no upstream is contacted.
local function sse_handler(conf, ctx, opts)
    -- send SSE headers and first chunk
    core.response.set_header("Content-Type", "text/event-stream")
    core.response.set_header("Cache-Control", "no-cache")

    local server = opts.server

    -- send endpoint event to advertise the message endpoint
    server.transport:send(conf.base_uri .. "/message?sessionId=" .. server.session_id, "endpoint")

    if opts.event_handler and opts.event_handler.on_client_message then
        server:on(mcp_server.EVENT_CLIENT_MESSAGE, function(message, additional)
            additional.server = server
            opts.event_handler.on_client_message(message, additional)
        end)
    end

    if opts.event_handler and opts.event_handler.on_connect then
        local code, body = opts.event_handler.on_connect({ server = server })
        if code then
            return code, body
        end
        -- NOTE(review): the server loop only runs when an on_connect handler
        -- is supplied; callers in this codebase always provide one -- confirm
        server:start() -- this is a sync call that only returns when the client disconnects
    end

    -- guard the handler table here too: the original indexed
    -- opts.event_handler unconditionally and crashed when no handlers
    -- were provided, unlike the two guarded checks above
    if opts.event_handler and opts.event_handler.on_disconnect then
        opts.event_handler.on_disconnect({ server = server })
        server:close()
    end

    ngx_exit(0) -- exit current phase, skip the upstream module
end
--- Handle a POSTed client message: forward the raw body to the session.
-- Returns an HTTP status: 400 when the body is missing, 500 when the
-- broker push fails, 202 when the message is accepted.
local function message_handler(conf, ctx, opts)
    local payload = core.request.get_body(nil, ctx)
    if not payload then
        return 400
    end

    local ok, err = opts.server:push_message(payload)
    if ok then
        return 202
    end

    core.log.error("failed to add task to queue: ", err)
    return 500
end
-- Route MCP traffic under conf.base_uri:
--   GET  <base_uri>/sse      -> open a new SSE session
--   POST <base_uri>/message  -> push a client message into a session
-- Anything else is answered with 404.
function _M.access(conf, ctx, opts)
    -- NOTE(review): conf.base_uri is spliced into the regex unescaped, so
    -- regex metacharacters in base_uri would change the match — confirm
    -- base_uri is validated/escaped upstream.
    local m, err = re_match(ctx.var.uri, "^" .. conf.base_uri .. "/(.*)", "jo")
    if err then
        -- FIX: message previously read "failed to mcp base uri" (garbled)
        core.log.info("failed to match mcp base uri: ", err)
        return core.response.exit(404)
    end

    local action = m and m[1] or false
    if not action then
        return core.response.exit(404)
    end

    if action == V241105_ENDPOINT_SSE and core.request.get_method() == "GET" then
        opts.server = mcp_server.new({})
        return sse_handler(conf, ctx, opts)
    end

    if action == V241105_ENDPOINT_MESSAGE and core.request.get_method() == "POST" then
        -- TODO: check ctx.var.arg_sessionId
        -- recover server instead of create
        opts.server = mcp_server.new({ session_id = ctx.var.arg_sessionId })
        return core.response.exit(message_handler(conf, ctx, opts))
    end

    return core.response.exit(404)
end
return _M

View File

@@ -0,0 +1,44 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local setmetatable = setmetatable
local type = type
local ngx = ngx
local ngx_print = ngx.print
local ngx_flush = ngx.flush
local core = require("apisix.core")
local _M = {}
-- instances created by _M.new() resolve methods through the module table
local mt = { __index = _M }
-- Create a new (stateless) SSE transport instance.
function _M.new()
    local transport = {}
    return setmetatable(transport, mt)
end
-- Write one SSE frame to the client and flush it immediately.
-- `message` may be a table (JSON-encoded) or a preformatted string;
-- `event_type` defaults to "message". Returns ngx truthy value on success,
-- or (nil/false, err) on a write failure.
function _M.send(self, message, event_type)
    local payload = type(message) == "table" and core.json.encode(message) or message
    local frame = "event: " .. (event_type or "message")
                  .. "\ndata: " .. payload .. "\n\n"

    local ok, err = ngx_print(frame)
    if not ok then
        return ok, "failed to write buffer: " .. err
    end

    -- synchronous flush so the event reaches the client right away
    return ngx_flush(true)
end
return _M

View File

@@ -0,0 +1,243 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local xml2lua = require("xml2lua")
local json = core.json
local math = math
local ngx = ngx
local ngx_re = ngx.re
local pairs = pairs
local string = string
local table = table
local type = type
-- media types this plugin can serve; any ";charset=..." suffix is stripped
-- by parse_content_type before the lookup in check_schema
local support_content_type = {
["application/xml"] = true,
["application/json"] = true,
["text/plain"] = true,
["text/html"] = true,
["text/xml"] = true
}
-- plugin configuration schema
local schema = {
type = "object",
properties = {
-- specify response delay time,default 0ms
-- NOTE(review): the value is passed straight to ngx.sleep, which takes
-- seconds — confirm whether "ms" above is accurate
delay = { type = "integer", default = 0 },
-- specify response status,default 200
response_status = { type = "integer", default = 200, minimum = 100 },
-- specify response content type, support application/xml, text/plain
-- and application/json, default application/json
content_type = { type = "string", default = "application/json;charset=utf8" },
-- specify response body.
response_example = { type = "string" },
-- specify response json schema, if response_example is not nil, this conf will be ignore.
-- generate random response by json schema.
response_schema = { type = "object" },
with_mock_header = { type = "boolean", default = true },
-- extra headers for the mocked response; values may reference nginx vars
response_headers = {
type = "object",
minProperties = 1,
patternProperties = {
["^[^:]+$"] = {
oneOf = {
{ type = "string" },
{ type = "number" }
}
}
},
}
},
-- at least one body source (literal example or schema) must be configured
anyOf = {
{ required = { "response_example" } },
{ required = { "response_schema" } }
}
}
-- plugin registration table
local _M = {
version = 0.1,
priority = 10900,
name = "mocking",
schema = schema,
}
-- Split a "type;parameter" Content-Type value into its two halves.
-- A nil input yields "", and a value without ';' is returned unchanged.
local function parse_content_type(content_type)
    if not content_type then
        return ""
    end

    local captures = ngx_re.match(content_type, "([ -~]*);([ -~]*)", "jo")
    if captures and #captures == 2 then
        -- e.g. "application/json;charset=utf8" -> "application/json", "charset=utf8"
        return captures[1], captures[2]
    end

    return content_type
end
-- Validate the plugin configuration: JSON-schema check first, then make
-- sure the configured media type (charset suffix ignored) is supported.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    if not ok then
        return false, err
    end

    local media_type = parse_content_type(conf.content_type)
    if not support_content_type[media_type] then
        return false, "unsupported content type!"
    end

    return true
end
-- Produce a string value: a string example is returned verbatim,
-- otherwise 1-10 random lowercase ASCII letters are generated.
local function gen_string(example)
    if type(example) == "string" then
        return example
    end

    local chars = {}
    for i = 1, math.random(1, 10) do
        chars[i] = string.char(math.random(97, 122))
    end
    return table.concat(chars)
end

-- Produce a number: a numeric example is returned verbatim,
-- otherwise a random float in [0, 10000).
local function gen_number(example)
    if type(example) == "number" then
        return example
    end
    return math.random() * 10000
end

-- Produce an integer: a numeric example is floored,
-- otherwise a random integer in [1, 10000].
local function gen_integer(example)
    if type(example) == "number" then
        return math.floor(example)
    end
    return math.random(1, 10000)
end

-- Produce a boolean: a boolean example is returned verbatim,
-- otherwise a fair coin flip.
local function gen_boolean(example)
    if type(example) == "boolean" then
        return example
    end
    return math.random(0, 1) == 1
end

-- forward declarations for the mutually recursive schema walkers
local gen_array, gen_object, gen_by_property

-- Generate an array of 1-3 items from `property.items`;
-- nil when no item schema is given.
function gen_array(property)
    local item_schema = property.items
    if item_schema == nil then
        return nil
    end

    local result = {}
    for i = 1, math.random(1, 3) do
        result[i] = gen_by_property(item_schema)
    end
    return result
end

-- Generate an object with one generated value per declared property;
-- an empty table when no properties are declared.
function gen_object(property)
    local result = {}
    local props = property.properties
    if props then
        for field, field_schema in pairs(props) do
            result[field] = gen_by_property(field_schema)
        end
    end
    return result
end

-- Dispatch on the (case-insensitive) JSON-schema `type`;
-- unknown types yield nil.
function gen_by_property(property)
    local typ = string.lower(property.type)

    if typ == "array" then
        return gen_array(property)
    elseif typ == "object" then
        return gen_object(property)
    elseif typ == "string" then
        return gen_string(property.example)
    elseif typ == "number" then
        return gen_number(property.example)
    elseif typ == "integer" then
        return gen_integer(property.example)
    elseif typ == "boolean" then
        return gen_boolean(property.example)
    end

    return nil
end
-- Build and return the mocked response. A literal response_example wins;
-- otherwise a random body is generated from response_schema (XML or JSON
-- only — other media types log an error and send an empty body).
function _M.access(conf, ctx)
    local body = ""
    if conf.response_example then
        body = conf.response_example
    else
        local generated = gen_object(conf.response_schema)
        local media_type = parse_content_type(conf.content_type)

        if media_type == "application/xml" or media_type == "text/xml" then
            body = xml2lua.toXml(generated, "data")
        elseif media_type == "application/json" or media_type == "text/plain" then
            body = json.encode(generated)
        else
            core.log.error("json schema body only support xml and json content type")
        end
    end

    ngx.header["Content-Type"] = conf.content_type
    if conf.with_mock_header then
        ngx.header["x-mock-by"] = "APISIX/" .. core.version.VERSION
    end

    if conf.response_headers then
        for name, value in pairs(conf.response_headers) do
            -- header values may reference nginx variables (e.g. $remote_addr)
            value = core.utils.resolve_var(value, ctx.var)
            core.response.add_header(name, value)
        end
    end

    -- NOTE(review): ngx.sleep takes seconds, while the schema comment says
    -- "ms" — confirm the intended unit of conf.delay.
    if conf.delay > 0 then
        ngx.sleep(conf.delay)
    end

    return conf.response_status, core.utils.resolve_var(body, ctx.var)
end
return _M

View File

@@ -0,0 +1,105 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local require = require
local pairs = pairs
local type = type
local plugin = require("apisix.plugin")
-- configuration schema: requires a list of at least two candidate
-- authentication plugin configurations to try in order
local schema = {
type = "object",
title = "work with route or service object",
properties = {
auth_plugins = { type = "array", minItems = 2 }
},
required = { "auth_plugins" },
}
local plugin_name = "multi-auth"
-- type = 'auth' marks this as an authentication plugin; check_schema below
-- requires every wrapped plugin to carry the same marker
local _M = {
version = 0.1,
priority = 2600,
type = 'auth',
name = plugin_name,
schema = schema
}
-- Validate the multi-auth configuration: the outer schema first, then each
-- wrapped plugin must exist, be of type 'auth', and pass its own schema check.
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    if not ok then
        return false, err
    end

    for _, auth_plugin in pairs(conf.auth_plugins) do
        for auth_plugin_name, auth_plugin_conf in pairs(auth_plugin) do
            local auth = plugin.get(auth_plugin_name)
            if auth == nil then
                -- FIX: message previously read "plugin did not found"
                return false, auth_plugin_name .. " plugin not found"
            end

            -- only authentication plugins may be combined here
            if auth.type ~= 'auth' then
                return false, auth_plugin_name .. " plugin is not supported"
            end

            -- delegate to the wrapped plugin's own schema validation
            local sub_ok, sub_err = auth.check_schema(auth_plugin_conf, auth.schema)
            if not sub_ok then
                return false, "plugin " .. auth_plugin_name .. " check schema failed: " .. sub_err
            end
        end
    end

    return true
end
-- Try each configured auth plugin in turn; the first one that accepts the
-- request wins. If every plugin rejects it, log the collected failures and
-- respond 401.
function _M.rewrite(conf, ctx)
    local auth_plugins = conf.auth_plugins
    local status_code
    local errors = {}

    for _, auth_plugin in pairs(auth_plugins) do
        for auth_plugin_name, auth_plugin_conf in pairs(auth_plugin) do
            local auth = plugin.get(auth_plugin_name)
            -- returns 401 HTTP status code if authentication failed, otherwise returns nothing.
            local auth_code, err = auth.rewrite(auth_plugin_conf, ctx)
            if type(err) == "table" then
                err = err.message -- compat
            end

            status_code = auth_code
            if auth_code == nil then
                core.log.debug(auth_plugin_name .. " succeed to authenticate the request")
                goto authenticated
            end

            -- FIX: err may be nil (a plugin can return a bare status code);
            -- tostring() prevents a "attempt to concatenate a nil value" crash
            -- that would turn an auth failure into a 500.
            core.table.insert(errors, auth_plugin_name ..
                " failed to authenticate the request, code: "
                .. auth_code .. ". error: " .. tostring(err))
        end
    end

    ::authenticated::
    if status_code ~= nil then
        -- also avoid shadowing the built-in `error` with the loop variable
        for _, err_msg in pairs(errors) do
            core.log.warn(err_msg)
        end
        return 401, { message = "Authorization Failed" }
    end
end
return _M

Some files were not shown because too many files have changed in this diff Show More