Skip to content

Commit

Permalink
Extract query ID from all kill_query procedure variations
Browse files Browse the repository at this point in the history
  • Loading branch information
willmostly committed Sep 23, 2024
1 parent 23e8320 commit e805660
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import com.google.common.base.Splitter;
import com.google.common.io.CharStreams;
import io.airlift.log.Logger;
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
import io.trino.gateway.ha.router.TrinoQueryProperties;
import jakarta.servlet.http.HttpServletRequest;

import java.io.InputStreamReader;
Expand Down Expand Up @@ -50,7 +52,6 @@ public final class ProxyUtils
* capitalization.
*/
private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern.compile(".*(?:%2F|(?i)query_?id(?-i)=|^)(\\d+_\\d+_\\d+_\\w+).*");
private static final Pattern EXTRACT_BETWEEN_SINGLE_QUOTES = Pattern.compile("'([^\\s']+)'");

private ProxyUtils() {}

Expand Down Expand Up @@ -89,26 +90,19 @@ public static String getQueryUser(String userHeader, String authorization)
return parts.get(0);
}

public static String extractQueryIdIfPresent(HttpServletRequest request, List<String> statementPaths)
public static String extractQueryIdIfPresent(HttpServletRequest request, List<String> statementPaths, RequestAnalyzerConfig requestAnalyzerConfig)
{
String path = request.getRequestURI();
String queryParams = request.getQueryString();
try {
String queryText = CharStreams.toString(new InputStreamReader(request.getInputStream()));
if (!isNullOrEmpty(queryText)
&& queryText.toLowerCase().contains("system.runtime.kill_query")) {
// extract and return the queryId
String[] parts = queryText.split(",");
for (String part : parts) {
if (part.contains("query_id")) {
Matcher matcher = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(part);
if (matcher.find()) {
String queryQuoted = matcher.group();
if (!isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) {
return queryQuoted.substring(1, queryQuoted.length() - 1);
}
}
}
&& queryText.toLowerCase().contains("kill_query")) {
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig);
if (trinoQueryProperties.getProcedure().isPresent()
&& trinoQueryProperties.getProcedure().orElseThrow().getName().getParts().getLast().equalsIgnoreCase("kill_query")) {
return trinoQueryProperties.getProcedureArgs().getFirst().getValue().toString()
.replaceFirst("^'", "").replaceFirst("'$", "");
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import io.airlift.log.Logger;
import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider;
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
import io.trino.gateway.ha.router.GatewayCookie;
import io.trino.gateway.ha.router.RoutingGroupSelector;
import io.trino.gateway.ha.router.RoutingManager;
Expand Down Expand Up @@ -45,18 +46,22 @@ public class RoutingTargetHandler
private final RoutingGroupSelector routingGroupSelector;
private final List<String> statementPaths;
private final List<Pattern> extraWhitelistPaths;
private final RequestAnalyzerConfig requestAnalyzerConfig;
private final boolean cookiesEnabled;

public RoutingTargetHandler(
RoutingManager routingManager,
RoutingGroupSelector routingGroupSelector,
List<String> statementPaths,
List<String> extraWhitelistPaths)
List<String> extraWhitelistPaths,
RequestAnalyzerConfig requestAnalyzerConfig)
{
this.routingManager = requireNonNull(routingManager);

this.routingGroupSelector = requireNonNull(routingGroupSelector);
this.statementPaths = requireNonNull(statementPaths);
this.extraWhitelistPaths = extraWhitelistPaths.stream().map(Pattern::compile).collect(toImmutableList());
this.requestAnalyzerConfig = requestAnalyzerConfig;
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
}

Expand Down Expand Up @@ -94,7 +99,7 @@ private String getBackendFromRoutingGroup(HttpServletRequest request)

private Optional<String> getPreviousBackend(HttpServletRequest request)
{
String queryId = extractQueryIdIfPresent(request, statementPaths);
String queryId = extractQueryIdIfPresent(request, statementPaths, requestAnalyzerConfig);
if (!isNullOrEmpty(queryId)) {
return Optional.of(routingManager.findBackendForQueryId(queryId));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ public RoutingTargetHandler getRoutingTargetHandler(
routingManager,
routingGroupSelector,
configuration.getStatementPaths(),
configuration.getExtraWhitelistPaths());
configuration.getExtraWhitelistPaths(),
configuration.getRequestAnalyzerConfig());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.ser.std.StdSerializer;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.compress.zstd.ZstdDecompressor;
Expand All @@ -29,6 +30,8 @@
import io.trino.sql.parser.SqlParser;
import io.trino.sql.tree.AddColumn;
import io.trino.sql.tree.Analyze;
import io.trino.sql.tree.Call;
import io.trino.sql.tree.CallArgument;
import io.trino.sql.tree.CreateCatalog;
import io.trino.sql.tree.CreateMaterializedView;
import io.trino.sql.tree.CreateSchema;
Expand Down Expand Up @@ -94,6 +97,8 @@ public class TrinoQueryProperties
private Set<String> catalogs = ImmutableSet.of();
private Set<String> schemas = ImmutableSet.of();
private Set<String> catalogSchemas = ImmutableSet.of();
private Optional<Call> procedure = Optional.empty();
private List<CallArgument> procedureArgs = ImmutableList.of();
private boolean isNewQuerySubmission;
private boolean isQueryParsingSuccessful;

Expand Down Expand Up @@ -262,6 +267,11 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
ImmutableSet.Builder<String> catalogSchemaBuilder)
throws RequestParsingException
{
if (node instanceof Call) {
procedure = Optional.of((Call) node);
procedureArgs = ((Call) node).getArguments();
return;
}
switch (node) {
case AddColumn s -> tableBuilder.add(qualifyName(s.getName()));
case Analyze s -> tableBuilder.add(qualifyName(s.getTableName()));
Expand Down Expand Up @@ -513,6 +523,16 @@ public boolean isQueryParsingSuccessful()
return isQueryParsingSuccessful;
}

public Optional<Call> getProcedure()
{
return procedure;
}

public List<CallArgument> getProcedureArgs()
{
return procedureArgs;
}

public static class AlternateStatementRequestBodyFormat
{
// Based on https://github.com/trinodb/trino/wiki/trino-v2-client-protocol, without session
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,29 @@
package io.trino.gateway.ha.handler;

import com.google.common.collect.ImmutableList;
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
import jakarta.servlet.ReadListener;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.HttpMethod;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestInstance.Lifecycle;
import org.mockito.Mockito;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.StringReader;
import java.util.List;

import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent;
import static io.trino.gateway.ha.handler.ProxyUtils.getQueryUser;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.AUTHORIZATION;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.when;

@TestInstance(Lifecycle.PER_CLASS)
public class TestQueryIdCachingProxyHandler
Expand Down Expand Up @@ -66,6 +75,186 @@ public void testExtractQueryIdFromUrl()
.isNull();
}

@Test
void testQueryIdFromKill()
throws IOException
{
RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig();
requestAnalyzerConfig.setAnalyzeRequest(true);
assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("CALL system.runtime.kill_query(query_id => '20200416_160256_03078_6b4yt', message => 'If he dies, he dies')"),
ImmutableList.of(), requestAnalyzerConfig))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("CALL system.runtime.kill_query(Query_id => '20200416_160256_03078_6b4yt', Message => 'If he dies, he dies')"),
ImmutableList.of(),
requestAnalyzerConfig))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("CALL kill_query('20200416_160256_03078_6b4yt', 'If he dies, he dies')"),
ImmutableList.of(),
requestAnalyzerConfig))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("CALL runtime.kill_query('20200416_160256_03078_6b4yt', '20200416_160256_03078_7n5uy')"),
ImmutableList.of(),
requestAnalyzerConfig))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', 'kill_query(''20200416_160256_03078_7n5uy'')')"),
ImmutableList.of(),
requestAnalyzerConfig))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', '20200416_160256_03078_7n5uy')"),
ImmutableList.of(),
requestAnalyzerConfig))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL system.runtime.kill_query(query_id=>'20200416_160256_03078_6b4yt')"),
ImmutableList.of(),
requestAnalyzerConfig))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt')"),
ImmutableList.of(),
requestAnalyzerConfig))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL kill_query('20200416_160256_03078_6b4yt')"), ImmutableList.of(), requestAnalyzerConfig))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("call Kill_Query('20200416_160256_03078_6b4yt')"), ImmutableList.of(), requestAnalyzerConfig))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("SELECT * FROM postgres.query_logs.queries WHERE sql LIKE '%kill_query(''20200416_160256%' "),
ImmutableList.of(),
requestAnalyzerConfig))
.isNull();

assertThat(extractQueryIdIfPresent(
prepareMockRequestWithBody("select * from postgres.query_logs.queries where sql like '%kill_query(''20200416_160256_03078_6b4yt' "),
ImmutableList.of(),
requestAnalyzerConfig))
.isNull();

assertThat(extractQueryIdIfPresent(
prepareMockRequestWithBody("select * from postgres.query_logs.queries where sql LIKE 'CALL kill_query(_20200416_160256_03078_6b4yt_)' "),
ImmutableList.of(),
requestAnalyzerConfig))
.isNull();

assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("""
--CALL kill_query('20200416_160256_03078_6b4yt', 'If he dies, he dies')
SELECT 1
"""),
ImmutableList.of(),
requestAnalyzerConfig))
.isNull();

assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("""
/*
CALL kill_query('20200416_160256_03078_6b4yt', 'If he dies, he dies')
*/
SELECT 1
"""),
ImmutableList.of(),
requestAnalyzerConfig))
.isNull();

assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("""
CALL KILL_QUERY('20200416_160256_03078_6b4yt', 'If he dies, he dies')
"""),
ImmutableList.of(),
requestAnalyzerConfig))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("""
CALL KILL_QUERY ('20200416_160256_03078_6b4yt', 'If he dies, he dies')
"""),
ImmutableList.of(),
requestAnalyzerConfig))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("""
CALL
KILL_QUERY
(
-- this is a comment
'20200416_160256_03078_6b4yt' --this is a trailing comment
,
/*
this is
a multiline comment
*/
'If he dies, he dies
')
"""),
ImmutableList.of(),
requestAnalyzerConfig))
.isEqualTo("20200416_160256_03078_6b4yt");
}

private static HttpServletRequest prepareMockRequestWithBody(String query)
throws IOException
{
HttpServletRequest request = Mockito.mock(HttpServletRequest.class);

ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(query.getBytes(UTF_8));
when(request.getMethod()).thenReturn(HttpMethod.POST);
when(request.getInputStream()).thenReturn(new ServletInputStream()
{
@Override
public boolean isFinished()
{
return byteArrayInputStream.available() > 0;
}

@Override
public boolean isReady()
{
return true;
}

@Override
public void setReadListener(ReadListener readListener)
{}

public int read()
throws IOException
{
return byteArrayInputStream.read();
}
});

when(request.getReader()).thenReturn(new BufferedReader(new StringReader(query)));

when(request.getQueryString()).thenReturn("");

return request;
}

@Test
public void testUserFromRequest()
throws IOException
Expand Down

0 comments on commit e805660

Please sign in to comment.