Skip to content

Commit

Permalink
Extract query ID from all kill_query procedure variations
Browse files Browse the repository at this point in the history
  • Loading branch information
willmostly committed Sep 3, 2024
1 parent 5afabe0 commit 95e782a
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,14 @@ public final class ProxyUtils
* capitalization.
*/
private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern.compile(".*(?:%2F|(?i)query_?id(?-i)=|^)(\\d+_\\d+_\\d+_\\w+).*");
private static final Pattern EXTRACT_BETWEEN_SINGLE_QUOTES = Pattern.compile("'([^\\s']+)'");

/**
* This regular expression extracts the query id from a CALL system.runtime.kill_query procedure call. It extracts the first string between quotes
* following the open parentheses after system.runtime.kill_query. The pattern handles the optional named arguments, and optional message argument
* as well as arbitrary whitespace.
*/
private static final Pattern KILL_QUERY_PROCEDURE_PATTERN
= Pattern.compile(".*system\\.runtime\\.kill_query\\s*\\(\\s*(query_id\\s*=>)?\\s*'([^\\\\s]+?)'(,\\s*(message\\s*=>\\s*)?('.*'))?\\)");

private ProxyUtils() {}

Expand Down Expand Up @@ -97,18 +104,9 @@ public static String extractQueryIdIfPresent(HttpServletRequest request)
String queryText = CharStreams.toString(new InputStreamReader(request.getInputStream()));
if (!isNullOrEmpty(queryText)
&& queryText.toLowerCase().contains("system.runtime.kill_query")) {
// extract and return the queryId
String[] parts = queryText.split(",");
for (String part : parts) {
if (part.contains("query_id")) {
Matcher matcher = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(part);
if (matcher.find()) {
String queryQuoted = matcher.group();
if (!isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) {
return queryQuoted.substring(1, queryQuoted.length() - 1);
}
}
}
Matcher matcher = KILL_QUERY_PROCEDURE_PATTERN.matcher(queryText.toLowerCase());
if (matcher.find()) {
return matcher.group(2);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,24 @@
*/
package io.trino.gateway.ha.handler;

import jakarta.servlet.ReadListener;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.http.HttpServletRequest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestInstance.Lifecycle;
import org.mockito.Mockito;

import java.io.ByteArrayInputStream;
import java.io.IOException;

import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent;
import static io.trino.gateway.ha.handler.ProxyUtils.getQueryUser;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.AUTHORIZATION;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.when;

@TestInstance(Lifecycle.PER_CLASS)
public class TestQueryIdCachingProxyHandler
Expand Down Expand Up @@ -61,6 +66,78 @@ public void testExtractQueryIdFromUrl()
.isNull();
}

@Test
void testQueryIdFromKill()
throws IOException
{
assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("CALL system.runtime.kill_query(query_id => '20200416_160256_03078_6b4yt', message => 'If he dies, he dies')")))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', 'If he dies, he dies')")))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', '''20200416_160256_03078_6b4yt''')")))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', 'kill_query(''20200416_160256_03078_6b4yt'')')")))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(
extractQueryIdIfPresent(
prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', '20200416_160256_03078_6b4yt')")))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL system.runtime.kill_query(query_id=>'20200416_160256_03078_6b4yt')")))
.isEqualTo("20200416_160256_03078_6b4yt");

assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt')")))
.isEqualTo("20200416_160256_03078_6b4yt");
}

private static HttpServletRequest prepareMockRequestWithBody(String query)
throws IOException
{
HttpServletRequest request = Mockito.mock(HttpServletRequest.class);

ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(query.getBytes(UTF_8));
when(request.getInputStream()).thenReturn(new ServletInputStream()
{
@Override
public boolean isFinished()
{
return byteArrayInputStream.available() > 0;
}

@Override
public boolean isReady()
{
return true;
}

@Override
public void setReadListener(ReadListener readListener)
{}

public int read()
throws IOException
{
return byteArrayInputStream.read();
}
});

when(request.getQueryString()).thenReturn("");

return request;
}

@Test
public void testUserFromRequest()
throws IOException
Expand Down

0 comments on commit 95e782a

Please sign in to comment.