diff --git a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBPreparedStatement.java b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBPreparedStatement.java index c92b6549bf9d..032b9769df36 100644 --- a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBPreparedStatement.java +++ b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBPreparedStatement.java @@ -19,7 +19,16 @@ package org.apache.iotdb.jdbc; +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.rpc.RpcUtils; +import org.apache.iotdb.rpc.StatementExecutionException; +import org.apache.iotdb.rpc.TSStatusCode; import org.apache.iotdb.service.rpc.thrift.IClientRPCService.Iface; +import org.apache.iotdb.service.rpc.thrift.TSDeallocatePreparedReq; +import org.apache.iotdb.service.rpc.thrift.TSExecutePreparedReq; +import org.apache.iotdb.service.rpc.thrift.TSExecuteStatementResp; +import org.apache.iotdb.service.rpc.thrift.TSPrepareReq; +import org.apache.iotdb.service.rpc.thrift.TSPrepareResp; import org.apache.thrift.TException; import org.apache.tsfile.common.conf.TSFileConfig; @@ -31,10 +40,11 @@ import java.io.IOException; import java.io.InputStream; import java.io.Reader; -import java.io.StringReader; import java.math.BigDecimal; import java.net.URL; +import java.nio.ByteBuffer; import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.sql.Array; import java.sql.Blob; import java.sql.Clob; @@ -52,7 +62,6 @@ import java.sql.Timestamp; import java.sql.Types; import java.text.DateFormat; -import java.text.ParsePosition; import java.text.SimpleDateFormat; import java.time.Instant; import java.time.ZoneId; @@ -62,16 +71,31 @@ import java.util.Calendar; import java.util.HashMap; import java.util.List; -import java.util.Locale; import java.util.Map; +import java.util.UUID; public class IoTDBPreparedStatement extends IoTDBStatement implements PreparedStatement { - private String sql; - private static final String METHOD_NOT_SUPPORTED_STRING = "Method not supported"; private static final Logger logger = LoggerFactory.getLogger(IoTDBPreparedStatement.class); + private static final String METHOD_NOT_SUPPORTED_STRING = "Method not supported"; + + private final String sql; + private final String preparedStatementName; + private final int parameterCount; + + // Parameter values stored as objects for binary serialization + private final Object[] parameterValues; + private final int[] parameterTypes; - /** save the SQL parameters as (paramLoc,paramValue) pairs. */ + // Parameter type constants for serialization + private static final byte TYPE_NULL = 0x00; + private static final byte TYPE_BOOLEAN = 0x01; + private static final byte TYPE_LONG = 0x02; + private static final byte TYPE_DOUBLE = 0x03; + private static final byte TYPE_STRING = 0x04; + private static final byte TYPE_BINARY = 0x05; + + /** save the SQL parameters as (paramLoc,paramValue) pairs for backward compatibility. */ private final Map parameters = new HashMap<>(); IoTDBPreparedStatement( @@ -84,14 +108,42 @@ public class IoTDBPreparedStatement extends IoTDBStatement implements PreparedSt throws SQLException { super(connection, client, sessionId, zoneId, charset); this.sql = sql; + this.preparedStatementName = generateStatementName(); + + // Send PREPARE request to server + TSPrepareReq prepareReq = new TSPrepareReq(); + prepareReq.setSessionId(sessionId); + prepareReq.setSql(sql); + prepareReq.setStatementName(preparedStatementName); + + try { + TSPrepareResp resp = client.prepareStatement(prepareReq); + RpcUtils.verifySuccess(resp.getStatus()); + + this.parameterCount = resp.isSetParameterCount() ? resp.getParameterCount() : 0; + this.parameterValues = new Object[parameterCount]; + this.parameterTypes = new int[parameterCount]; + + // Initialize all parameter types to NULL + for (int i = 0; i < parameterCount; i++) { + parameterTypes[i] = Types.NULL; + } + } catch (TException e) { + throw new SQLException("Failed to prepare statement: " + e.getMessage(), e); + } catch (StatementExecutionException e) { + throw new SQLException("Failed to prepare statement: " + e.getMessage(), e); + } } // Only for tests IoTDBPreparedStatement( IoTDBConnection connection, Iface client, Long sessionId, String sql, ZoneId zoneId) throws SQLException { - super(connection, client, sessionId, zoneId, TSFileConfig.STRING_CHARSET); - this.sql = sql; + this(connection, client, sessionId, sql, zoneId, TSFileConfig.STRING_CHARSET); + } + + private String generateStatementName() { + return "jdbc_ps_" + UUID.randomUUID().toString().replace("-", ""); } @Override @@ -102,26 +154,186 @@ public void addBatch() throws SQLException { @Override public void clearParameters() { this.parameters.clear(); + for (int i = 0; i < parameterCount; i++) { + parameterValues[i] = null; + parameterTypes[i] = Types.NULL; + } } @Override public boolean execute() throws SQLException { - return super.execute(createCompleteSql(sql, parameters)); + TSExecuteStatementResp resp = executeInternal(); + return resp.isSetQueryDataSet() || resp.isSetQueryResult(); } @Override public ResultSet executeQuery() throws SQLException { - return super.executeQuery(createCompleteSql(sql, parameters)); + TSExecuteStatementResp resp = executeInternal(); + return processQueryResult(resp); } @Override public int executeUpdate() throws SQLException { - return super.executeUpdate(createCompleteSql(sql, parameters)); + executeInternal(); + return 0; // IoTDB doesn't return affected row count + } + + private TSExecuteStatementResp executeInternal() throws SQLException { + // Validate all parameters are set + for (int i = 0; i < parameterCount; i++) { + if (parameterTypes[i] == Types.NULL + && parameterValues[i] == null + && !parameters.containsKey(i + 1)) { + throw new SQLException("Parameter #" + (i + 1) + " is unset"); + } + } + + TSExecutePreparedReq req = new TSExecutePreparedReq(); + req.setSessionId(sessionId); + req.setStatementName(preparedStatementName); + req.setParameters(serializeParameters()); + + if (queryTimeout > 0) { + req.setTimeout(queryTimeout * 1000L); + } + + try { + TSExecuteStatementResp resp = client.executePreparedStatement(req); + RpcUtils.verifySuccess(resp.getStatus()); + return resp; + } catch (TException e) { + throw new SQLException("Failed to execute prepared statement: " + e.getMessage(), e); + } catch (StatementExecutionException e) { + throw new SQLException("Failed to execute prepared statement: " + e.getMessage(), e); + } + } + + private ResultSet processQueryResult(TSExecuteStatementResp resp) throws SQLException { + if (resp.isSetQueryDataSet() || resp.isSetQueryResult()) { + // Create ResultSet from response + this.resultSet = + new IoTDBJDBCResultSet( + this, + resp.getColumns(), + resp.getDataTypeList(), + resp.columnNameIndexMap, + resp.ignoreTimeStamp, + client, + sql, + resp.queryId, + sessionId, + resp.queryResult, + resp.tracingInfo, + (long) queryTimeout * 1000, + resp.isSetMoreData() && resp.isMoreData(), + zoneId); + return resultSet; + } + return null; + } + + /** + * Serialize parameters to binary format for transmission. Format: [type:1byte][value:variable] + */ + private List serializeParameters() { + List serialized = new ArrayList<>(); + for (int i = 0; i < parameterCount; i++) { + serialized.add(serializeParameter(i)); + } + return serialized; + } + + private ByteBuffer serializeParameter(int index) { + Object value = parameterValues[index]; + int type = parameterTypes[index]; + + if (value == null || type == Types.NULL) { + return ByteBuffer.wrap(new byte[] {TYPE_NULL}); + } + + switch (type) { + case Types.BOOLEAN: + ByteBuffer boolBuf = ByteBuffer.allocate(2); + boolBuf.put(TYPE_BOOLEAN); + boolBuf.put((byte) ((Boolean) value ? 1 : 0)); + boolBuf.flip(); + return boolBuf; + + case Types.INTEGER: + case Types.BIGINT: + ByteBuffer longBuf = ByteBuffer.allocate(9); + longBuf.put(TYPE_LONG); + longBuf.putLong(((Number) value).longValue()); + longBuf.flip(); + return longBuf; + + case Types.FLOAT: + case Types.DOUBLE: + ByteBuffer doubleBuf = ByteBuffer.allocate(9); + doubleBuf.put(TYPE_DOUBLE); + doubleBuf.putDouble(((Number) value).doubleValue()); + doubleBuf.flip(); + return doubleBuf; + + case Types.VARCHAR: + case Types.CHAR: + byte[] strBytes = ((String) value).getBytes(StandardCharsets.UTF_8); + ByteBuffer strBuf = ByteBuffer.allocate(5 + strBytes.length); + strBuf.put(TYPE_STRING); + strBuf.putInt(strBytes.length); + strBuf.put(strBytes); + strBuf.flip(); + return strBuf; + + case Types.BINARY: + case Types.VARBINARY: + byte[] binBytes = (byte[]) value; + ByteBuffer binBuf = ByteBuffer.allocate(5 + binBytes.length); + binBuf.put(TYPE_BINARY); + binBuf.putInt(binBytes.length); + binBuf.put(binBytes); + binBuf.flip(); + return binBuf; + + default: + // Fallback: serialize as string + String strValue = String.valueOf(value); + byte[] defaultBytes = strValue.getBytes(StandardCharsets.UTF_8); + ByteBuffer defaultBuf = ByteBuffer.allocate(5 + defaultBytes.length); + defaultBuf.put(TYPE_STRING); + defaultBuf.putInt(defaultBytes.length); + defaultBuf.put(defaultBytes); + defaultBuf.flip(); + return defaultBuf; + } + } + + @Override + public void close() throws SQLException { + if (!isClosed()) { + // Deallocate prepared statement on server + TSDeallocatePreparedReq req = new TSDeallocatePreparedReq(); + req.setSessionId(sessionId); + req.setStatementName(preparedStatementName); + + try { + TSStatus status = client.deallocatePreparedStatement(req); + if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + logger.warn("Failed to deallocate prepared statement: {}", status.getMessage()); + } + } catch (TException e) { + logger.warn("Error deallocating prepared statement", e); + } + } + super.close(); } @Override public ResultSetMetaData getMetaData() throws SQLException { - return getResultSet().getMetaData(); + if (resultSet != null) { + return resultSet.getMetaData(); + } + return null; } @Override @@ -129,7 +341,7 @@ public ParameterMetaData getParameterMetaData() { return new ParameterMetaData() { @Override public int getParameterCount() { - return parameters.size(); + return parameterCount; } @Override @@ -139,43 +351,26 @@ public int isNullable(int param) { @Override public boolean isSigned(int param) { - try { - return Integer.parseInt(parameters.get(param)) < 0; - } catch (Exception e) { - return false; - } + int type = parameterTypes[param - 1]; + return type == Types.INTEGER + || type == Types.BIGINT + || type == Types.FLOAT + || type == Types.DOUBLE; } @Override public int getPrecision(int param) { - return parameters.get(param).length(); + return 0; } @Override public int getScale(int param) { - try { - double d = Double.parseDouble(parameters.get(param)); - if (d >= 1) { // we only need the fraction digits - d = d - (long) d; - } - if (d == 0) { // nothing to count - return 0; - } - d *= 10; // shifts 1 digit to left - int count = 1; - while (d - (long) d != 0) { // keeps shifting until there are no more fractions - d *= 10; - count++; - } - return count; - } catch (Exception e) { - return 0; - } + return 0; } @Override public int getParameterType(int param) { - return 0; + return parameterTypes[param - 1]; } @Override @@ -190,7 +385,7 @@ public String getParameterClassName(int param) { @Override public int getParameterMode(int param) { - return 0; + return ParameterMetaData.parameterModeIn; } @Override @@ -205,799 +400,347 @@ public boolean isWrapperFor(Class iface) { }; } + // ================== Parameter Setters ================== + @Override - public void setArray(int parameterIndex, Array x) throws SQLException { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + public void setNull(int parameterIndex, int sqlType) throws SQLException { + checkParameterIndex(parameterIndex); + parameterValues[parameterIndex - 1] = null; + parameterTypes[parameterIndex - 1] = Types.NULL; + this.parameters.put(parameterIndex, "NULL"); } @Override - public void setAsciiStream(int parameterIndex, InputStream x) throws SQLException { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + public void setNull(int parameterIndex, int sqlType, String typeName) throws SQLException { + setNull(parameterIndex, sqlType); } @Override - public void setAsciiStream(int parameterIndex, InputStream x, int length) throws SQLException { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + public void setBoolean(int parameterIndex, boolean x) throws SQLException { + checkParameterIndex(parameterIndex); + parameterValues[parameterIndex - 1] = x; + parameterTypes[parameterIndex - 1] = Types.BOOLEAN; + this.parameters.put(parameterIndex, Boolean.toString(x)); } @Override - public void setAsciiStream(int parameterIndex, InputStream x, long length) throws SQLException { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + public void setInt(int parameterIndex, int x) throws SQLException { + checkParameterIndex(parameterIndex); + parameterValues[parameterIndex - 1] = (long) x; + parameterTypes[parameterIndex - 1] = Types.INTEGER; + this.parameters.put(parameterIndex, Integer.toString(x)); } @Override - public void setBigDecimal(int parameterIndex, BigDecimal x) throws SQLException { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + public void setLong(int parameterIndex, long x) throws SQLException { + checkParameterIndex(parameterIndex); + parameterValues[parameterIndex - 1] = x; + parameterTypes[parameterIndex - 1] = Types.BIGINT; + this.parameters.put(parameterIndex, Long.toString(x)); } @Override - public void setBinaryStream(int parameterIndex, InputStream x) throws SQLException { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + public void setFloat(int parameterIndex, float x) throws SQLException { + checkParameterIndex(parameterIndex); + parameterValues[parameterIndex - 1] = (double) x; + parameterTypes[parameterIndex - 1] = Types.FLOAT; + this.parameters.put(parameterIndex, Float.toString(x)); } @Override - public void setBinaryStream(int parameterIndex, InputStream x, int length) throws SQLException { - byte[] bytes = null; - try { - bytes = ReadWriteIOUtils.readBytes(x, length); - StringBuilder sb = new StringBuilder(); - for (byte b : bytes) { - sb.append(String.format("%02x", b)); - } - this.parameters.put(parameterIndex, "X'" + sb.toString() + "'"); - } catch (IOException e) { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + public void setDouble(int parameterIndex, double x) throws SQLException { + checkParameterIndex(parameterIndex); + parameterValues[parameterIndex - 1] = x; + parameterTypes[parameterIndex - 1] = Types.DOUBLE; + this.parameters.put(parameterIndex, Double.toString(x)); + } + + @Override + public void setString(int parameterIndex, String x) throws SQLException { + checkParameterIndex(parameterIndex); + parameterValues[parameterIndex - 1] = x; + parameterTypes[parameterIndex - 1] = Types.VARCHAR; + if (x == null) { + this.parameters.put(parameterIndex, null); + } else { + this.parameters.put(parameterIndex, "'" + escapeSingleQuotes(x) + "'"); } } @Override - public void setBinaryStream(int parameterIndex, InputStream x, long length) throws SQLException { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + public void setBytes(int parameterIndex, byte[] x) throws SQLException { + checkParameterIndex(parameterIndex); + parameterValues[parameterIndex - 1] = x; + parameterTypes[parameterIndex - 1] = Types.BINARY; + Binary binary = new Binary(x); + this.parameters.put(parameterIndex, binary.getStringValue(TSFileConfig.STRING_CHARSET)); } @Override - public void setBlob(int parameterIndex, Blob x) throws SQLException { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + public void setDate(int parameterIndex, Date x) throws SQLException { + checkParameterIndex(parameterIndex); + DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd"); + String dateStr = dateFormat.format(x); + parameterValues[parameterIndex - 1] = dateStr; + parameterTypes[parameterIndex - 1] = Types.VARCHAR; + this.parameters.put(parameterIndex, "'" + dateStr + "'"); } @Override - public void setBlob(int parameterIndex, InputStream inputStream) throws SQLException { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + public void setDate(int parameterIndex, Date x, Calendar cal) throws SQLException { + setDate(parameterIndex, x); } @Override - public void setBlob(int parameterIndex, InputStream inputStream, long length) - throws SQLException { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + public void setTime(int parameterIndex, Time x) throws SQLException { + checkParameterIndex(parameterIndex); + try { + long time = x.getTime(); + String timeprecision = client.getProperties().getTimestampPrecision(); + switch (timeprecision.toLowerCase()) { + case "ms": + break; + case "us": + time = time * 1000; + break; + case "ns": + time = time * 1000000; + break; + default: + break; + } + parameterValues[parameterIndex - 1] = time; + parameterTypes[parameterIndex - 1] = Types.BIGINT; + this.parameters.put(parameterIndex, Long.toString(time)); + } catch (TException e) { + throw new SQLException("Failed to get time precision: " + e.getMessage(), e); + } } @Override - public void setBoolean(int parameterIndex, boolean x) { - this.parameters.put(parameterIndex, Boolean.toString(x)); + public void setTime(int parameterIndex, Time x, Calendar cal) throws SQLException { + setTime(parameterIndex, x); } @Override - public void setByte(int parameterIndex, byte x) throws SQLException { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + public void setTimestamp(int parameterIndex, Timestamp x) throws SQLException { + checkParameterIndex(parameterIndex); + ZonedDateTime zonedDateTime = + ZonedDateTime.ofInstant(Instant.ofEpochMilli(x.getTime()), super.zoneId); + String tsStr = zonedDateTime.format(DateTimeFormatter.ISO_LOCAL_DATE_TIME); + parameterValues[parameterIndex - 1] = tsStr; + parameterTypes[parameterIndex - 1] = Types.VARCHAR; + this.parameters.put(parameterIndex, tsStr); } @Override - public void setBytes(int parameterIndex, byte[] x) throws SQLException { - Binary binary = new Binary(x); - this.parameters.put(parameterIndex, binary.getStringValue(TSFileConfig.STRING_CHARSET)); + public void setTimestamp(int parameterIndex, Timestamp x, Calendar cal) throws SQLException { + setTimestamp(parameterIndex, x); } @Override - public void setCharacterStream(int parameterIndex, Reader reader) throws SQLException { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + public void setObject(int parameterIndex, Object x) throws SQLException { + if (x == null) { + setNull(parameterIndex, Types.NULL); + } else if (x instanceof String) { + setString(parameterIndex, (String) x); + } else if (x instanceof Integer) { + setInt(parameterIndex, (Integer) x); + } else if (x instanceof Long) { + setLong(parameterIndex, (Long) x); + } else if (x instanceof Float) { + setFloat(parameterIndex, (Float) x); + } else if (x instanceof Double) { + setDouble(parameterIndex, (Double) x); + } else if (x instanceof Boolean) { + setBoolean(parameterIndex, (Boolean) x); + } else if (x instanceof Timestamp) { + setTimestamp(parameterIndex, (Timestamp) x); + } else if (x instanceof Date) { + setDate(parameterIndex, (Date) x); + } else if (x instanceof Time) { + setTime(parameterIndex, (Time) x); + } else if (x instanceof byte[]) { + setBytes(parameterIndex, (byte[]) x); + } else { + throw new SQLException( + String.format( + "Can't infer the SQL type for an instance of %s. Use setObject() with explicit type.", + x.getClass().getName())); + } } @Override - public void setCharacterStream(int parameterIndex, Reader reader, int length) - throws SQLException { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQLException { + setObject(parameterIndex, x); } @Override - public void setCharacterStream(int parameterIndex, Reader reader, long length) + public void setObject(int parameterIndex, Object parameterObj, int targetSqlType, int scale) throws SQLException { - throw new SQLException(Constant.PARAMETER_SUPPORTED); + setObject(parameterIndex, parameterObj); } + private void checkParameterIndex(int index) throws SQLException { + if (index < 1 || index > parameterCount) { + throw new SQLException( + "Parameter index out of range: " + index + " (expected 1-" + parameterCount + ")"); + } + } + + private String escapeSingleQuotes(String value) { + return value.replace("'", "''"); + } + + // ================== Unsupported Methods ================== + @Override - public void setClob(int parameterIndex, Clob x) throws SQLException { + public void setArray(int parameterIndex, Array x) throws SQLException { throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setClob(int parameterIndex, Reader reader) throws SQLException { + public void setAsciiStream(int parameterIndex, InputStream x) throws SQLException { throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setClob(int parameterIndex, Reader reader, long length) throws SQLException { + public void setAsciiStream(int parameterIndex, InputStream x, int length) throws SQLException { throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setDate(int parameterIndex, Date x) throws SQLException { - DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd"); - this.parameters.put(parameterIndex, "'" + dateFormat.format(x) + "'"); + public void setAsciiStream(int parameterIndex, InputStream x, long length) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setDate(int parameterIndex, Date x, Calendar cal) throws SQLException { + public void setBigDecimal(int parameterIndex, BigDecimal x) throws SQLException { throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setDouble(int parameterIndex, double x) { - this.parameters.put(parameterIndex, Double.toString(x)); + public void setBinaryStream(int parameterIndex, InputStream x) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setFloat(int parameterIndex, float x) { - this.parameters.put(parameterIndex, Float.toString(x)); + public void setBinaryStream(int parameterIndex, InputStream x, int length) throws SQLException { + try { + byte[] bytes = ReadWriteIOUtils.readBytes(x, length); + setBytes(parameterIndex, bytes); + } catch (IOException e) { + throw new SQLException("Failed to read binary stream: " + e.getMessage(), e); + } } @Override - public void setInt(int parameterIndex, int x) { - this.parameters.put(parameterIndex, Integer.toString(x)); + public void setBinaryStream(int parameterIndex, InputStream x, long length) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setLong(int parameterIndex, long x) { - this.parameters.put(parameterIndex, Long.toString(x)); + public void setBlob(int parameterIndex, Blob x) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setNCharacterStream(int parameterIndex, Reader value) throws SQLException { + public void setBlob(int parameterIndex, InputStream inputStream) throws SQLException { throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setNCharacterStream(int parameterIndex, Reader value, long length) + public void setBlob(int parameterIndex, InputStream inputStream, long length) throws SQLException { throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setNClob(int parameterIndex, NClob value) throws SQLException { + public void setByte(int parameterIndex, byte x) throws SQLException { throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setNClob(int parameterIndex, Reader reader) throws SQLException { + public void setCharacterStream(int parameterIndex, Reader reader) throws SQLException { throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setNClob(int parameterIndex, Reader reader, long length) throws SQLException { + public void setCharacterStream(int parameterIndex, Reader reader, int length) + throws SQLException { throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setNString(int parameterIndex, String value) throws SQLException { + public void setCharacterStream(int parameterIndex, Reader reader, long length) + throws SQLException { throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setNull(int parameterIndex, int sqlType) throws SQLException { - this.parameters.put(parameterIndex, "NULL"); + public void setClob(int parameterIndex, Clob x) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setNull(int parameterIndex, int sqlType, String typeName) throws SQLException { - throw new SQLException(Constant.PARAMETER_NOT_NULL); + public void setClob(int parameterIndex, Reader reader) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setObject(int parameterIndex, Object x) throws SQLException { - if (x instanceof String) { - setString(parameterIndex, (String) x); - } else if (x instanceof Integer) { - setInt(parameterIndex, (Integer) x); - } else if (x instanceof Long) { - setLong(parameterIndex, (Long) x); - } else if (x instanceof Float) { - setFloat(parameterIndex, (Float) x); - } else if (x instanceof Double) { - setDouble(parameterIndex, (Double) x); - } else if (x instanceof Boolean) { - setBoolean(parameterIndex, (Boolean) x); - } else if (x instanceof Timestamp) { - setTimestamp(parameterIndex, (Timestamp) x); - } else if (x instanceof Date) { - setDate(parameterIndex, (Date) x); - } else if (x instanceof Blob) { - setBlob(parameterIndex, (Blob) x); - } else if (x instanceof Time) { - setTime(parameterIndex, (Time) x); - } else { - // Can't infer a type. - throw new SQLException( - String.format( - "Can''t infer the SQL type to use for an instance of %s. Use setObject() with" - + " an explicit Types value to specify the type to use.", - x.getClass().getName())); - } + public void setClob(int parameterIndex, Reader reader, long length) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQLException { - if (!(x instanceof BigDecimal)) { - setObject(parameterIndex, x, targetSqlType, 0); - } else { - setObject(parameterIndex, x, targetSqlType, ((BigDecimal) x).scale()); - } + public void setNCharacterStream(int parameterIndex, Reader value) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); } - @SuppressWarnings({ - "squid:S3776", - "squid:S6541" - }) // ignore Cognitive Complexity of methods should not be too high - // ignore Methods should not perform too many tasks (aka Brain method) @Override - public void setObject(int parameterIndex, Object parameterObj, int targetSqlType, int scale) + public void setNCharacterStream(int parameterIndex, Reader value, long length) throws SQLException { - if (parameterObj == null) { - setNull(parameterIndex, java.sql.Types.OTHER); - } else { - try { - switch (targetSqlType) { - case Types.BOOLEAN: - if (parameterObj instanceof Boolean) { - setBoolean(parameterIndex, ((Boolean) parameterObj).booleanValue()); - break; - } else if (parameterObj instanceof String) { - if ("true".equalsIgnoreCase((String) parameterObj) - || "Y".equalsIgnoreCase((String) parameterObj)) { - setBoolean(parameterIndex, true); - } else if ("false".equalsIgnoreCase((String) parameterObj) - || "N".equalsIgnoreCase((String) parameterObj)) { - setBoolean(parameterIndex, false); - } else { - throw new SQLException( - "No conversion from " + parameterObj + " to Types.BOOLEAN possible."); - } - break; - } else if (parameterObj instanceof Number) { - int intValue = ((Number) parameterObj).intValue(); - - setBoolean(parameterIndex, intValue != 0); - - break; - } else { - throw new SQLException( - "No conversion from " + parameterObj + " to Types.BOOLEAN possible."); - } - - case Types.BIT: - case Types.TINYINT: - case Types.SMALLINT: - case Types.INTEGER: - case Types.BIGINT: - case Types.REAL: - case Types.FLOAT: - case Types.DOUBLE: - case Types.DECIMAL: - case Types.NUMERIC: - setNumericObject(parameterIndex, parameterObj, targetSqlType, scale); - break; - case Types.CHAR: - case Types.VARCHAR: - case Types.LONGVARCHAR: - if (parameterObj instanceof BigDecimal) { - setString( - parameterIndex, - StringUtils.fixDecimalExponent( - StringUtils.consistentToString((BigDecimal) parameterObj))); - } else { - setString(parameterIndex, parameterObj.toString()); - } - - break; - - case Types.CLOB: - if (parameterObj instanceof java.sql.Clob) { - setClob(parameterIndex, (java.sql.Clob) parameterObj); - } else { - setString(parameterIndex, parameterObj.toString()); - } - - break; - - case Types.BINARY: - case Types.VARBINARY: - case Types.LONGVARBINARY: - case Types.BLOB: - throw new SQLException(Constant.PARAMETER_SUPPORTED); - case Types.DATE: - case Types.TIMESTAMP: - java.util.Date parameterAsDate; - - if (parameterObj instanceof String) { - ParsePosition pp = new ParsePosition(0); - DateFormat sdf = - new SimpleDateFormat(getDateTimePattern((String) parameterObj, false), Locale.US); - parameterAsDate = sdf.parse((String) parameterObj, pp); - } else { - parameterAsDate = (Date) parameterObj; - } - - switch (targetSqlType) { - case Types.DATE: - if (parameterAsDate instanceof java.sql.Date) { - setDate(parameterIndex, (java.sql.Date) parameterAsDate); - } else { - setDate(parameterIndex, new java.sql.Date(parameterAsDate.getTime())); - } - - break; - - case Types.TIMESTAMP: - if (parameterAsDate instanceof java.sql.Timestamp) { - setTimestamp(parameterIndex, (java.sql.Timestamp) parameterAsDate); - } else { - setTimestamp(parameterIndex, new java.sql.Timestamp(parameterAsDate.getTime())); - } - - break; - default: - logger.error("No type was matched"); - break; - } - - break; - - case Types.TIME: - if (parameterObj instanceof String) { - DateFormat sdf = - new SimpleDateFormat(getDateTimePattern((String) parameterObj, true), Locale.US); - setTime(parameterIndex, new Time(sdf.parse((String) parameterObj).getTime())); - } else if (parameterObj instanceof Timestamp) { - Timestamp xT = (Timestamp) parameterObj; - setTime(parameterIndex, new Time(xT.getTime())); - } else { - setTime(parameterIndex, (Time) parameterObj); - } - - break; - - case Types.OTHER: - throw new SQLException(Constant.PARAMETER_SUPPORTED); // - default: - throw new SQLException(Constant.PARAMETER_SUPPORTED); // - } - } catch (SQLException ex) { - throw ex; - } catch (Exception ex) { - throw new SQLException(Constant.PARAMETER_SUPPORTED); // - } - } - } - - @SuppressWarnings({ - "squid:S3776", - "squid:S6541" - }) // ignore Cognitive Complexity of methods should not be too high - // ignore Methods should not perform too many tasks (aka Brain method) - private final String getDateTimePattern(String dt, boolean toTime) throws Exception { - // - // Special case - // - int dtLength = (dt != null) ? dt.length() : 0; - - if ((dtLength >= 8) && (dtLength <= 10)) { - int dashCount = 0; - boolean isDateOnly = true; - - for (int i = 0; i < dtLength; i++) { - char c = dt.charAt(i); - - if (!Character.isDigit(c) && (c != '-')) { - isDateOnly = false; - - break; - } - - if (c == '-') { - dashCount++; - } - } - - if (isDateOnly && (dashCount == 2)) { - return "yyyy-MM-dd"; - } - } - boolean colonsOnly = true; - - for (int i = 0; i < dtLength; i++) { - char c = dt.charAt(i); - - if (!Character.isDigit(c) && (c != ':')) { - colonsOnly = false; - - break; - } - } - - if (colonsOnly) { - return "HH:mm:ss"; - } - - int n; - int z; - int count; - int maxvecs; - char c; - char separator; - StringReader reader = new StringReader(dt + " "); - ArrayList vec = new ArrayList<>(); - ArrayList vecRemovelist = new ArrayList<>(); - Object[] nv = new Object[3]; - Object[] v; - nv[0] = Character.valueOf('y'); - nv[1] = new StringBuilder(); - nv[2] = Integer.valueOf(0); - vec.add(nv); - - if (toTime) { - nv = new Object[3]; - nv[0] = Character.valueOf('h'); - nv[1] = new StringBuilder(); - nv[2] = Integer.valueOf(0); - vec.add(nv); - } - - while ((z = reader.read()) != -1) { - separator = (char) z; - maxvecs = vec.size(); - - for (count = 0; count < maxvecs; count++) { - v = vec.get(count); - n = ((Integer) v[2]).intValue(); - c = getSuccessor(((Character) v[0]).charValue(), n); - - if (!Character.isLetterOrDigit(separator)) { - if ((c == ((Character) v[0]).charValue()) && (c != 'S')) { - vecRemovelist.add(v); - } else { - ((StringBuilder) v[1]).append(separator); - - if ((c == 'X') || (c == 'Y')) { - v[2] = Integer.valueOf(4); - } - } - } else { - if (c == 'X') { - c = 'y'; - nv = new Object[3]; - nv[1] = (new StringBuilder(((StringBuilder) v[1]).toString())).append('M'); - nv[0] = Character.valueOf('M'); - nv[2] = Integer.valueOf(1); - vec.add(nv); - } else if (c == 'Y') { - c = 'M'; - nv = new Object[3]; - nv[1] = (new StringBuilder(((StringBuilder) v[1]).toString())).append('d'); - nv[0] = Character.valueOf('d'); - nv[2] = Integer.valueOf(1); - vec.add(nv); - } - - ((StringBuilder) v[1]).append(c); - - if (c == ((Character) v[0]).charValue()) { - v[2] = Integer.valueOf(n + 1); - } else { - v[0] = Character.valueOf(c); - v[2] = Integer.valueOf(1); - } - } - } - - int size = vecRemovelist.size(); - - for (int i = 0; i < size; i++) { - v = vecRemovelist.get(i); - vec.remove(v); - } - - vecRemovelist.clear(); - } - - int size = vec.size(); - - for (int i = 0; i < size; i++) { - v = vec.get(i); - c = ((Character) v[0]).charValue(); - n = ((Integer) v[2]).intValue(); - - boolean bk = getSuccessor(c, n) != c; - boolean atEnd = (((c == 's') || (c == 'm') || ((c == 'h') && toTime)) && bk); - boolean finishesAtDate = (bk && (c == 'd') && !toTime); - boolean containsEnd = (((StringBuilder) v[1]).toString().indexOf('W') != -1); - - if ((!atEnd && !finishesAtDate) || (containsEnd)) { - vecRemovelist.add(v); - } - } - - size = vecRemovelist.size(); - - for (int i = 0; i < size; i++) { - vec.remove(vecRemovelist.get(i)); - } - - vecRemovelist.clear(); - v = vec.get(0); // might throw exception - - StringBuilder format = (StringBuilder) v[1]; - format.setLength(format.length() - 1); - - return format.toString(); - } - - @SuppressWarnings({"squid:S3776", "squid:S3358"}) // ignore Ternary operators should not be nested - // ignore Cognitive Complexity of methods should not be too high - private final char getSuccessor(char c, int n) { - return ((c == 'y') && (n == 2)) - ? 'X' - : (((c == 'y') && (n < 4)) - ? 'y' - : ((c == 'y') - ? 'M' - : (((c == 'M') && (n == 2)) - ? 'Y' - : (((c == 'M') && (n < 3)) - ? 'M' - : ((c == 'M') - ? 'd' - : (((c == 'd') && (n < 2)) - ? 'd' - : ((c == 'd') - ? 'H' - : (((c == 'H') && (n < 2)) - ? 'H' - : ((c == 'H') - ? 'm' - : (((c == 'm') && (n < 2)) - ? 'm' - : ((c == 'm') - ? 's' - : (((c == 's') && (n < 2)) - ? 's' - : 'W')))))))))))); - } - - @SuppressWarnings({ - "squid:S3776", - "squid:S6541" - }) // ignore Cognitive Complexity of methods should not be too high - // ignore Methods should not perform too many tasks (aka Brain method) - private void setNumericObject( - int parameterIndex, Object parameterObj, int targetSqlType, int scale) throws SQLException { - Number parameterAsNum; - - if (parameterObj instanceof Boolean) { - parameterAsNum = - ((Boolean) parameterObj).booleanValue() ? Integer.valueOf(1) : Integer.valueOf(0); - } else if (parameterObj instanceof String) { - switch (targetSqlType) { - case Types.BIT: - if ("1".equals(parameterObj) || "0".equals(parameterObj)) { - parameterAsNum = Integer.valueOf((String) parameterObj); - } else { - boolean parameterAsBoolean = "true".equalsIgnoreCase((String) parameterObj); - - parameterAsNum = parameterAsBoolean ? Integer.valueOf(1) : Integer.valueOf(0); - } - - break; - - case Types.TINYINT: - case Types.SMALLINT: - case Types.INTEGER: - parameterAsNum = Integer.valueOf((String) parameterObj); - - break; - - case Types.BIGINT: - parameterAsNum = Long.valueOf((String) parameterObj); - - break; - - case Types.REAL: - parameterAsNum = Float.valueOf((String) parameterObj); - - break; - - case Types.FLOAT: - case Types.DOUBLE: - parameterAsNum = Double.valueOf((String) parameterObj); - - break; - - case Types.DECIMAL: - case Types.NUMERIC: - default: - parameterAsNum = new java.math.BigDecimal((String) parameterObj); - } - } else { - parameterAsNum = (Number) parameterObj; - } - - switch (targetSqlType) { - case Types.BIT: - case Types.TINYINT: - case Types.SMALLINT: - case Types.INTEGER: - setInt(parameterIndex, parameterAsNum.intValue()); - break; - - case Types.BIGINT: - setLong(parameterIndex, parameterAsNum.longValue()); - break; - - case Types.REAL: - setFloat(parameterIndex, parameterAsNum.floatValue()); - break; - - case Types.FLOAT: - setFloat(parameterIndex, parameterAsNum.floatValue()); - break; - case Types.DOUBLE: - setDouble(parameterIndex, parameterAsNum.doubleValue()); - - break; - - case Types.DECIMAL: - case Types.NUMERIC: - if (parameterAsNum instanceof java.math.BigDecimal) { - BigDecimal scaledBigDecimal = null; - - try { - scaledBigDecimal = ((java.math.BigDecimal) parameterAsNum).setScale(scale); - } catch (ArithmeticException ex) { - try { - scaledBigDecimal = - ((java.math.BigDecimal) parameterAsNum).setScale(scale, BigDecimal.ROUND_HALF_UP); - } catch (ArithmeticException arEx) { - throw new SQLException( - "Can't set scale of '" - + scale - + "' for DECIMAL argument '" - + parameterAsNum - + "'"); - } - } - - setBigDecimal(parameterIndex, scaledBigDecimal); - } else if (parameterAsNum instanceof java.math.BigInteger) { - setBigDecimal( - parameterIndex, - new java.math.BigDecimal((java.math.BigInteger) parameterAsNum, scale)); - } else { - setBigDecimal(parameterIndex, BigDecimal.valueOf(parameterAsNum.doubleValue())); - } - - break; - default: - } - } - - @Override - public void setRef(int parameterIndex, Ref x) throws SQLException { throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setRowId(int parameterIndex, RowId x) throws SQLException { - throw new SQLException(METHOD_NOT_SUPPORTED_STRING); + public void setNClob(int parameterIndex, NClob value) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setSQLXML(int parameterIndex, SQLXML xmlObject) throws SQLException { - throw new SQLException(METHOD_NOT_SUPPORTED_STRING); + public void setNClob(int parameterIndex, Reader reader) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setShort(int parameterIndex, short x) throws SQLException { + public void setNClob(int parameterIndex, Reader reader, long length) throws SQLException { throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setString(int parameterIndex, String x) { - if (x == null) { - this.parameters.put(parameterIndex, null); - } else { - this.parameters.put(parameterIndex, "'" + escapeSingleQuotes(x) + "'"); - } - } - - private String escapeSingleQuotes(String value) { - // Escape single quotes with double single quotes - return value.replace("'", "''"); + public void setNString(int parameterIndex, String value) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setTime(int parameterIndex, Time x) throws SQLException { - try { - long time = x.getTime(); - String timeprecision = client.getProperties().getTimestampPrecision(); - switch (timeprecision.toLowerCase()) { - case "ms": - break; - case "us": - time = time * 1000; - break; - case "ns": - time = time * 1000000; - break; - default: - break; - } - setLong(parameterIndex, time); - } catch (TException e) { - logger.error( - String.format("set time error when iotdb prepared statement :%s ", e.getMessage())); - } + public void setRef(int parameterIndex, Ref x) throws SQLException { + throw new SQLException(Constant.PARAMETER_SUPPORTED); } @Override - public void setTime(int parameterIndex, Time x, Calendar cal) throws SQLException { - try { - ZonedDateTime zonedDateTime = null; - long time = x.getTime(); - String timeprecision = client.getProperties().getTimestampPrecision(); - switch (timeprecision.toLowerCase()) { - case "ms": - break; - case "us": - time = time * 1000; - break; - case "ns": - time = time * 1000000; - break; - default: - break; - } - if (cal != null) { - zonedDateTime = - ZonedDateTime.ofInstant( - Instant.ofEpochMilli(time), ZoneId.of(cal.getTimeZone().getID())); - } else { - zonedDateTime = ZonedDateTime.ofInstant(Instant.ofEpochMilli(time), super.zoneId); - } - this.parameters.put( - parameterIndex, zonedDateTime.format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)); - } catch (TException e) { - logger.error( - String.format("set time error when iotdb prepared statement :%s ", e.getMessage())); - } + public void setRowId(int parameterIndex, RowId x) throws SQLException { + throw new SQLException(METHOD_NOT_SUPPORTED_STRING); } @Override - public void setTimestamp(int parameterIndex, Timestamp x) { - ZonedDateTime zonedDateTime = - ZonedDateTime.ofInstant(Instant.ofEpochMilli(x.getTime()), super.zoneId); - this.parameters.put( - parameterIndex, zonedDateTime.format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)); + public void setSQLXML(int parameterIndex, SQLXML xmlObject) throws SQLException { + throw new SQLException(METHOD_NOT_SUPPORTED_STRING); } @Override - public void setTimestamp(int parameterIndex, Timestamp x, Calendar cal) throws SQLException { - ZonedDateTime zonedDateTime = null; - if (cal != null) { - zonedDateTime = - ZonedDateTime.ofInstant( - Instant.ofEpochMilli(x.getTime()), ZoneId.of(cal.getTimeZone().getID())); - } else { - zonedDateTime = ZonedDateTime.ofInstant(Instant.ofEpochMilli(x.getTime()), super.zoneId); - } - this.parameters.put( - parameterIndex, zonedDateTime.format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)); + public void setShort(int parameterIndex, short x) throws SQLException { + setInt(parameterIndex, x); } @Override @@ -1010,16 +753,14 @@ public void setUnicodeStream(int parameterIndex, InputStream x, int length) thro throw new SQLException(Constant.PARAMETER_SUPPORTED); } + // ================== Helper Methods for Backward Compatibility ================== + private String createCompleteSql(final String sql, Map parameters) throws SQLException { List parts = splitSqlStatement(sql); StringBuilder newSql = new StringBuilder(parts.get(0)); for (int i = 1; i < parts.size(); i++) { - if (logger.isDebugEnabled()) { - logger.debug("SQL {}", sql); - logger.debug("parameters {}", parameters.size()); - } if (!parameters.containsKey(i)) { throw new SQLException("Parameter #" + i + " is unset"); } @@ -1043,15 +784,12 @@ private List splitSqlStatement(final String sql) { } switch (c) { case '\'': - // skip something like 'xxxxx' apCount++; break; case '\\': - // skip something like \r\n skip = true; break; case '?': - // for input like: select a from 'bc' where d, 'bc' will be skipped if ((apCount & 1) == 0) { parts.add(sql.substring(off, i)); off = i + 1; diff --git a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBStatement.java b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBStatement.java index 93a922070db2..8cb0a32417f2 100644 --- a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBStatement.java +++ b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBStatement.java @@ -55,7 +55,7 @@ public class IoTDBStatement implements Statement { private final IoTDBConnection connection; - private ResultSet resultSet = null; + protected ResultSet resultSet = null; private int fetchSize; private int maxRows = 0; @@ -66,7 +66,7 @@ public class IoTDBStatement implements Statement { * Timeout of query can be set by users. Unit: s. A negative number means using the default * configuration of server. And value 0 will disable the function of query timeout. */ - private int queryTimeout = -1; + protected int queryTimeout = -1; protected IClientRPCService.Iface client; private List batchSQLList; @@ -82,7 +82,7 @@ public class IoTDBStatement implements Statement { /** Add SQLWarnings to the warningChain if needed. */ private SQLWarning warningChain = null; - private long sessionId; + protected long sessionId; private long stmtId = -1; private long queryId = -1; diff --git a/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/IoTDBPreparedStatementTest.java b/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/IoTDBPreparedStatementTest.java index 2ae65dfed2ae..aa932cda5f22 100644 --- a/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/IoTDBPreparedStatementTest.java +++ b/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/IoTDBPreparedStatementTest.java @@ -22,23 +22,27 @@ import org.apache.iotdb.common.rpc.thrift.TSStatus; import org.apache.iotdb.rpc.TSStatusCode; import org.apache.iotdb.service.rpc.thrift.IClientRPCService.Iface; +import org.apache.iotdb.service.rpc.thrift.TSExecutePreparedReq; import org.apache.iotdb.service.rpc.thrift.TSExecuteStatementReq; import org.apache.iotdb.service.rpc.thrift.TSExecuteStatementResp; +import org.apache.iotdb.service.rpc.thrift.TSPrepareReq; +import org.apache.iotdb.service.rpc.thrift.TSPrepareResp; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import java.sql.SQLException; import java.sql.Timestamp; import java.sql.Types; import java.time.ZoneId; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -46,6 +50,7 @@ public class IoTDBPreparedStatementTest { @Mock TSExecuteStatementResp execStatementResp; + @Mock TSPrepareResp prepareResp; @Mock TSStatus getOperationStatusResp; private ZoneId zoneId = ZoneId.systemDefault(); @Mock private IoTDBConnection connection; @@ -62,6 +67,53 @@ public void before() throws Exception { when(execStatementResp.getQueryId()).thenReturn(queryId); when(client.executeStatementV2(any(TSExecuteStatementReq.class))).thenReturn(execStatementResp); + + // Mock for prepareStatement - dynamically calculate parameter count from SQL + when(client.prepareStatement(any(TSPrepareReq.class))) + .thenAnswer( + new Answer() { + @Override + public TSPrepareResp answer(InvocationOnMock invocation) throws Throwable { + TSPrepareReq req = invocation.getArgument(0); + String sql = req.getSql(); + int paramCount = countQuestionMarks(sql); + + TSPrepareResp resp = new TSPrepareResp(); + resp.setStatus(Status_SUCCESS); + resp.setParameterCount(paramCount); + return resp; + } + }); + + // Mock for executePreparedStatement + when(client.executePreparedStatement(any(TSExecutePreparedReq.class))) + .thenReturn(execStatementResp); + } + + /** Count the number of '?' placeholders in a SQL string, ignoring those inside quotes */ + private int countQuestionMarks(String sql) { + int count = 0; + boolean inSingleQuote = false; + boolean inDoubleQuote = false; + + for (int i = 0; i < sql.length(); i++) { + char c = sql.charAt(i); + + if (c == '\'' && !inDoubleQuote) { + // Check for escaped quote + if (i + 1 < sql.length() && sql.charAt(i + 1) == '\'') { + i++; // Skip escaped quote + } else { + inSingleQuote = !inSingleQuote; + } + } else if (c == '"' && !inSingleQuote) { + inDoubleQuote = !inDoubleQuote; + } else if (c == '?' && !inSingleQuote && !inDoubleQuote) { + count++; + } + } + + return count; } @SuppressWarnings("resource") @@ -73,23 +125,27 @@ public void testNonParameterized() throws Exception { new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT status, temperature FROM root.ln.wf01.wt01 WHERE temperature < 24 and time > 2017-11-1 0:13:00", - argument.getValue().getStatement()); + // Verify executePreparedStatement was called (new behavior) + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + // Non-parameterized query should have empty parameters + assertTrue( + argument.getValue().getParameters() == null + || argument.getValue().getParameters().isEmpty()); } @SuppressWarnings("resource") @Test public void unusedArgument() throws SQLException { + // SQL with no parameters - setting a parameter should throw an exception String sql = "SELECT status, temperature FROM root.ln.wf01.wt01 WHERE temperature < 24 and time > 2017-11-1 0:13:00"; IoTDBPreparedStatement ps = new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); - ps.setString(1, "123"); - assertFalse(ps.execute()); + // In the new server-side prepared statement implementation, setting a parameter + // that doesn't exist in the SQL throws an exception + assertThrows(SQLException.class, () -> ps.setString(1, "123")); } @SuppressWarnings("resource") @@ -111,12 +167,11 @@ public void oneIntArgument() throws Exception { new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); ps.setInt(1, 123); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT status, temperature FROM root.ln.wf01.wt01 WHERE temperature < 123 and time > 2017-11-1 0:13:00", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + // Verify parameters were sent + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -128,12 +183,10 @@ public void oneLongArgument() throws Exception { new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); ps.setLong(1, 123); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT status, temperature FROM root.ln.wf01.wt01 WHERE temperature < 123 and time > 2017-11-1 0:13:00", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -145,12 +198,10 @@ public void oneFloatArgument() throws Exception { new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); ps.setFloat(1, 123.133f); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT status, temperature FROM root.ln.wf01.wt01 WHERE temperature < 123.133 and time > 2017-11-1 0:13:00", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -162,12 +213,10 @@ public void oneDoubleArgument() throws Exception { new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); ps.setDouble(1, 123.456); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT status, temperature FROM root.ln.wf01.wt01 WHERE temperature < 123.456 and time > 2017-11-1 0:13:00", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -179,12 +228,10 @@ public void oneBooleanArgument() throws Exception { new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); ps.setBoolean(1, false); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT status, temperature FROM root.ln.wf01.wt01 WHERE temperature < false and time > 2017-11-1 0:13:00", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -196,12 +243,10 @@ public void oneStringArgument1() throws Exception { new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); ps.setString(1, "'abcde'"); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT status, temperature FROM root.ln.wf01.wt01 WHERE temperature < '''abcde''' and time > 2017-11-1 0:13:00", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -213,12 +258,10 @@ public void oneStringArgument2() throws Exception { new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); ps.setString(1, "\"abcde\""); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT status, temperature FROM root.ln.wf01.wt01 WHERE temperature < '\"abcde\"' and time > 2017-11-1 0:13:00", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -229,11 +272,10 @@ public void oneStringArgument3() throws Exception { new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); ps.setString(1, "temperature"); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT status, 'temperature' FROM root.ln.wf01.wt01", argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -244,12 +286,10 @@ public void oneTimeLongArgument() throws Exception { new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); ps.setLong(1, 1233); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT status, temperature FROM root.ln.wf01.wt01 WHERE time > 1233", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -260,12 +300,10 @@ public void oneTimeTimestampArgument() throws Exception { new IoTDBPreparedStatement(connection, client, sessionId, sql, zoneId); ps.setTimestamp(1, Timestamp.valueOf("2017-11-01 00:13:00")); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT status, temperature FROM root.ln.wf01.wt01 WHERE time > 2017-11-01T00:13:00", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -278,12 +316,10 @@ public void escapingOfStringArgument() throws Exception { ps.setLong(1, 1333); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT status, temperature FROM root.ln.wf01.wt01 WHERE status = '134' and temperature = 1333", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -297,12 +333,10 @@ public void pastingIntoEscapedQuery() throws Exception { ps.setDouble(1, -1323.0); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT status, temperature FROM root.ln.wf01.wt01 WHERE status = '\\044e' || temperature = -1323.0", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -321,12 +355,10 @@ public void testInsertStatement1() throws Exception { ps.setString(7, "'abc'"); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "INSERT INTO root.ln.wf01.wt01(time,a,b,c,d,e,f) VALUES(12324,false,123,123234345,123.423,-1323.0,'''abc''')", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -347,12 +379,10 @@ public void testInsertStatement2() throws Exception { ps.setString(9, "'abc'"); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "INSERT INTO root.ln.wf01.wt01(time,a,b,c,d,e,f,g,h) VALUES(2017-11-01T00:13:00,false,123,123234345,123.423,-1323.0,'\"abc\"','abc','''abc''')", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @Test @@ -370,12 +400,10 @@ public void testInsertStatement3() throws Exception { ps.setObject(7, "\"abc\"", Types.VARCHAR); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "INSERT INTO root.ln.wf01.wt02(time,a,b,c,d,e,f) VALUES(2020-01-01T10:10:10,false,123,123234345,123.423,-1323.0,'\"abc\"')", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @Test @@ -393,12 +421,10 @@ public void testInsertStatement4() throws Exception { ps.setObject(7, "abc", Types.VARCHAR); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "INSERT INTO root.ln.wf01.wt02(time,a,b,c,d,e,f) VALUES(2020-01-01T10:10:10,false,123,123234345,123.423,-1323.0,'abc')", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } // ========== Table Model SQL Injection Prevention Tests ========== @@ -415,12 +441,11 @@ public void testTableModelLoginInjectionWithComment() throws Exception { ps.setString(2, "password"); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT * FROM users WHERE username = 'admin'' --' AND password = 'password'", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + // SQL injection is prevented by using prepared statements with parameterized queries + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -435,12 +460,11 @@ public void testTableModelLoginInjectionWithORCondition() throws Exception { ps.setString(2, "' OR '1'='1"); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT * FROM users WHERE username = 'admin' AND password = ''' OR ''1''=''1'", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + // SQL injection is prevented by using prepared statements with parameterized queries + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -453,12 +477,11 @@ public void testTableModelQueryWithMultipleInjectionVectors() throws Exception { ps.setString(1, "'; DROP TABLE users;"); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT * FROM users WHERE email = '''; DROP TABLE users;'", - argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + // SQL injection is prevented by using prepared statements with parameterized queries + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -471,10 +494,10 @@ public void testTableModelString1() throws Exception { ps.setString(1, "a'b"); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals("SELECT * FROM users WHERE password = 'a''b'", argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -487,10 +510,10 @@ public void testTableModelString2() throws Exception { ps.setString(1, "a\'b"); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals("SELECT * FROM users WHERE password = 'a''b'", argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -503,11 +526,10 @@ public void testTableModelString3() throws Exception { ps.setString(1, "a\\'b"); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT * FROM users WHERE password = 'a\\''b'", argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -520,11 +542,10 @@ public void testTableModelString4() throws Exception { ps.setString(1, "a\\\'b"); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals( - "SELECT * FROM users WHERE password = 'a\\''b'", argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } @SuppressWarnings("resource") @@ -537,9 +558,9 @@ public void testTableModelStringWithNull() throws Exception { ps.setString(1, null); ps.execute(); - ArgumentCaptor argument = - ArgumentCaptor.forClass(TSExecuteStatementReq.class); - verify(client).executeStatementV2(argument.capture()); - assertEquals("SELECT * FROM users WHERE email = null", argument.getValue().getStatement()); + ArgumentCaptor argument = + ArgumentCaptor.forClass(TSExecutePreparedReq.class); + verify(client).executePreparedStatement(argument.capture()); + assertTrue(argument.getValue().getParameters() != null); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java index 167a1fa914fd..b8248801f741 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java @@ -53,6 +53,7 @@ import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; import org.apache.iotdb.db.protocol.session.IClientSession; +import org.apache.iotdb.db.protocol.session.PreparedStatementInfo; import org.apache.iotdb.db.protocol.session.SessionManager; import org.apache.iotdb.db.protocol.thrift.OperationType; import org.apache.iotdb.db.queryengine.common.SessionInfo; @@ -75,6 +76,7 @@ import org.apache.iotdb.db.queryengine.plan.analyze.schema.ISchemaFetcher; import org.apache.iotdb.db.queryengine.plan.execution.ExecutionResult; import org.apache.iotdb.db.queryengine.plan.execution.IQueryExecution; +import org.apache.iotdb.db.queryengine.plan.execution.config.session.PreparedStatementMemoryManager; import org.apache.iotdb.db.queryengine.plan.parser.ASTVisitor; import org.apache.iotdb.db.queryengine.plan.parser.StatementGenerator; import org.apache.iotdb.db.queryengine.plan.planner.LocalExecutionPlanner; @@ -89,7 +91,15 @@ import org.apache.iotdb.db.queryengine.plan.relational.metadata.fetcher.cache.TableId; import org.apache.iotdb.db.queryengine.plan.relational.metadata.fetcher.cache.TreeDeviceSchemaCacheManager; import org.apache.iotdb.db.queryengine.plan.relational.security.TreeAccessCheckContext; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ParameterExtractor; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BinaryLiteral; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DoubleLiteral; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Literal; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LongLiteral; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NullLiteral; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SetSqlDialect; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.StringLiteral; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Use; import org.apache.iotdb.db.queryengine.plan.relational.sql.parser.ParsingException; import org.apache.iotdb.db.queryengine.plan.relational.sql.parser.SqlParser; @@ -146,9 +156,11 @@ import org.apache.iotdb.service.rpc.thrift.TSCreateMultiTimeseriesReq; import org.apache.iotdb.service.rpc.thrift.TSCreateSchemaTemplateReq; import org.apache.iotdb.service.rpc.thrift.TSCreateTimeseriesReq; +import org.apache.iotdb.service.rpc.thrift.TSDeallocatePreparedReq; import org.apache.iotdb.service.rpc.thrift.TSDeleteDataReq; import org.apache.iotdb.service.rpc.thrift.TSDropSchemaTemplateReq; import org.apache.iotdb.service.rpc.thrift.TSExecuteBatchStatementReq; +import org.apache.iotdb.service.rpc.thrift.TSExecutePreparedReq; import org.apache.iotdb.service.rpc.thrift.TSExecuteStatementReq; import org.apache.iotdb.service.rpc.thrift.TSExecuteStatementResp; import org.apache.iotdb.service.rpc.thrift.TSFastLastDataQueryForOneDeviceReq; @@ -169,6 +181,8 @@ import org.apache.iotdb.service.rpc.thrift.TSLastDataQueryReq; import org.apache.iotdb.service.rpc.thrift.TSOpenSessionReq; import org.apache.iotdb.service.rpc.thrift.TSOpenSessionResp; +import org.apache.iotdb.service.rpc.thrift.TSPrepareReq; +import org.apache.iotdb.service.rpc.thrift.TSPrepareResp; import org.apache.iotdb.service.rpc.thrift.TSProtocolVersion; import org.apache.iotdb.service.rpc.thrift.TSPruneSchemaTemplateReq; import org.apache.iotdb.service.rpc.thrift.TSQueryDataSet; @@ -1488,6 +1502,225 @@ public TSStatus closeOperation(TSCloseOperationReq req) { COORDINATOR::cleanupQueryExecution); } + // ========================= PreparedStatement RPC Methods ========================= + + @Override + public TSPrepareResp prepareStatement(TSPrepareReq req) { + IClientSession clientSession = SESSION_MANAGER.getCurrSessionAndUpdateIdleTime(); + if (!SESSION_MANAGER.checkLogin(clientSession)) { + return new TSPrepareResp(getNotLoggedInStatus()); + } + + try { + String sql = req.getSql(); + String statementName = req.getStatementName(); + + if (clientSession.getPreparedStatement(statementName) != null) { + return new TSPrepareResp( + RpcUtils.getStatus( + TSStatusCode.EXECUTE_STATEMENT_ERROR, + String.format("Prepared statement '%s' already exists", statementName))); + } + + org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement statement = + relationSqlParser.createStatement(sql, clientSession.getZoneId(), clientSession); + + if (statement == null) { + return new TSPrepareResp( + RpcUtils.getStatus(TSStatusCode.SQL_PARSE_ERROR, "Failed to parse SQL: " + sql)); + } + + int parameterCount = ParameterExtractor.getParameterCount(statement); + + long memorySizeInBytes = statement.ramBytesUsed(); + + PreparedStatementMemoryManager.getInstance().allocate(statementName, memorySizeInBytes); + + PreparedStatementInfo info = + new PreparedStatementInfo(statementName, statement, memorySizeInBytes); + clientSession.addPreparedStatement(statementName, info); + + TSPrepareResp resp = new TSPrepareResp(RpcUtils.getStatus(TSStatusCode.SUCCESS_STATUS)); + resp.setParameterCount(parameterCount); + return resp; + } catch (Exception e) { + return new TSPrepareResp( + onQueryException( + e, OperationType.EXECUTE_STATEMENT.getName(), TSStatusCode.INTERNAL_SERVER_ERROR)); + } + } + + @Override + public TSExecuteStatementResp executePreparedStatement(TSExecutePreparedReq req) { + boolean finished = false; + long queryId = Long.MIN_VALUE; + IClientSession clientSession = SESSION_MANAGER.getCurrSessionAndUpdateIdleTime(); + + if (!SESSION_MANAGER.checkLogin(clientSession)) { + return RpcUtils.getTSExecuteStatementResp(getNotLoggedInStatus()); + } + + long startTime = System.nanoTime(); + Throwable t = null; + try { + String statementName = req.getStatementName(); + + PreparedStatementInfo preparedInfo = clientSession.getPreparedStatement(statementName); + if (preparedInfo == null) { + return RpcUtils.getTSExecuteStatementResp( + RpcUtils.getStatus( + TSStatusCode.EXECUTE_STATEMENT_ERROR, + String.format("Prepared statement '%s' does not exist", statementName))); + } + + List parameters = deserializeParameters(req.getParameters()); + + org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement statement = + preparedInfo.getSql(); + + int expectedCount = ParameterExtractor.getParameterCount(statement); + if (parameters.size() != expectedCount) { + return RpcUtils.getTSExecuteStatementResp( + RpcUtils.getStatus( + TSStatusCode.EXECUTE_STATEMENT_ERROR, + String.format( + "Parameter count mismatch: expected %d, got %d", + expectedCount, parameters.size()))); + } + + // Request query ID + queryId = SESSION_MANAGER.requestQueryId(clientSession, null); + + // Execute using Coordinator with external parameters + long timeout = req.isSetTimeout() ? req.getTimeout() : config.getQueryTimeoutThreshold(); + ExecutionResult result = + COORDINATOR.executeForTableModel( + statement, + relationSqlParser, + clientSession, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + "EXECUTE " + statementName, + metadata, + timeout, + true, + parameters); + + if (result.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode() + && result.status.code != TSStatusCode.REDIRECTION_RECOMMEND.getStatusCode()) { + finished = true; + return RpcUtils.getTSExecuteStatementResp(result.status); + } + + IQueryExecution queryExecution = COORDINATOR.getQueryExecution(queryId); + + try (SetThreadName threadName = new SetThreadName(result.queryId.getId())) { + TSExecuteStatementResp resp; + if (queryExecution != null && queryExecution.isQuery()) { + resp = createResponse(queryExecution.getDatasetHeader(), queryId); + resp.setStatus(result.status); + int fetchSize = + req.isSetFetchSize() ? req.getFetchSize() : config.getThriftMaxFrameSize(); + finished = setResultForPrepared.apply(resp, queryExecution, fetchSize); + resp.setMoreData(!finished); + } else { + finished = true; + resp = RpcUtils.getTSExecuteStatementResp(result.status); + } + return resp; + } + } catch (Exception e) { + finished = true; + t = e; + return RpcUtils.getTSExecuteStatementResp( + onQueryException( + e, OperationType.EXECUTE_STATEMENT.getName(), TSStatusCode.INTERNAL_SERVER_ERROR)); + } finally { + long currentOperationCost = System.nanoTime() - startTime; + if (finished) { + COORDINATOR.cleanupQueryExecution(queryId, null, t); + } + COORDINATOR.recordExecutionTime(queryId, currentOperationCost); + } + } + + @Override + public TSStatus deallocatePreparedStatement(TSDeallocatePreparedReq req) { + IClientSession clientSession = SESSION_MANAGER.getCurrSessionAndUpdateIdleTime(); + if (!SESSION_MANAGER.checkLogin(clientSession)) { + return getNotLoggedInStatus(); + } + + try { + String statementName = req.getStatementName(); + + PreparedStatementInfo removedInfo = clientSession.removePreparedStatement(statementName); + if (removedInfo == null) { + return RpcUtils.getStatus( + TSStatusCode.EXECUTE_STATEMENT_ERROR, + String.format("Prepared statement '%s' does not exist", statementName)); + } + + PreparedStatementMemoryManager.getInstance().release(removedInfo.getMemorySizeInBytes()); + + return RpcUtils.getStatus(TSStatusCode.SUCCESS_STATUS); + } catch (Exception e) { + return onQueryException( + e, OperationType.EXECUTE_STATEMENT.getName(), TSStatusCode.INTERNAL_SERVER_ERROR); + } + } + + private List deserializeParameters(List params) { + List literals = new ArrayList<>(); + for (ByteBuffer buf : params) { + buf.rewind(); + byte type = buf.get(); + switch (type) { + case 0x00: // Null + literals.add(new NullLiteral()); + break; + case 0x01: // Boolean + boolean boolVal = buf.get() != 0; + literals.add(new BooleanLiteral(boolVal ? "true" : "false")); + break; + case 0x02: // Long + long longVal = buf.getLong(); + literals.add(new LongLiteral(String.valueOf(longVal))); + break; + case 0x03: // Double + double doubleVal = buf.getDouble(); + literals.add(new DoubleLiteral(doubleVal)); + break; + case 0x04: // String + int strLen = buf.getInt(); + byte[] strBytes = new byte[strLen]; + buf.get(strBytes); + literals.add( + new StringLiteral(new String(strBytes, java.nio.charset.StandardCharsets.UTF_8))); + break; + case 0x05: // Binary + int binLen = buf.getInt(); + byte[] binBytes = new byte[binLen]; + buf.get(binBytes); + literals.add(new BinaryLiteral(binBytes)); + break; + default: + throw new IllegalArgumentException("Unknown parameter type: " + type); + } + } + return literals; + } + + private final SelectResult setResultForPrepared = + (resp, queryExecution, fetchSize) -> { + Pair pair = + QueryDataSetUtils.convertTsBlockByFetchSize(queryExecution, fetchSize); + resp.setQueryDataSet(pair.left); + return pair.right; + }; + + // ========================= End PreparedStatement RPC Methods ========================= + @Override public TSGetTimeZoneResp getTimeZone(long sessionId) { try { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java index 3210d277d861..87ae0c65539d 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java @@ -404,6 +404,47 @@ public ExecutionResult executeForTableModel( Metadata metadata, long timeOut, boolean userQuery) { + // Delegate to overloaded version with empty parameters + return executeForTableModel( + statement, + sqlParser, + clientSession, + queryId, + session, + sql, + metadata, + timeOut, + userQuery, + Collections.emptyList()); + } + + /** + * Execute a table model statement with optional pre-bound parameters. Used by JDBC + * PreparedStatement to execute cached AST with serialized parameters. + * + * @param statement The AST to execute + * @param sqlParser SQL parser instance + * @param clientSession Current client session + * @param queryId Query ID + * @param session Session info + * @param sql SQL string for logging + * @param metadata Metadata instance + * @param timeOut Query timeout + * @param userQuery Whether this is a user query + * @param externalParameters List of Literal parameters to bind (empty for normal execution) + * @return ExecutionResult containing execution status and query ID + */ + public ExecutionResult executeForTableModel( + org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement statement, + SqlParser sqlParser, + IClientSession clientSession, + long queryId, + SessionInfo session, + String sql, + Metadata metadata, + long timeOut, + boolean userQuery, + List externalParameters) { return execution( queryId, session, @@ -417,7 +458,8 @@ public ExecutionResult executeForTableModel( queryContext, metadata, timeOut > 0 ? timeOut : CONFIG.getQueryTimeoutThreshold(), - startTime))); + startTime, + externalParameters))); } public ExecutionResult executeForTableModel( @@ -481,7 +523,8 @@ private IQueryExecution createQueryExecutionForTableModel( final MPPQueryContext queryContext, final Metadata metadata, final long timeOut, - final long startTime) { + final long startTime, + final List externalParameters) { queryContext.setTimeOut(timeOut); queryContext.setStartTime(startTime); if (statement instanceof DropDB @@ -561,7 +604,11 @@ private IQueryExecution createQueryExecutionForTableModel( List parameters = Collections.emptyList(); Map, Expression> parameterLookup = Collections.emptyMap(); - if (statement instanceof Execute) { + // Handle external parameters from JDBC PreparedStatement (highest priority) + if (externalParameters != null && !externalParameters.isEmpty()) { + parameterLookup = ParameterExtractor.bindParameters(statement, externalParameters); + parameters = new ArrayList<>(externalParameters); + } else if (statement instanceof Execute) { Execute executeStatement = (Execute) statement; String statementName = executeStatement.getStatementName().getValue(); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreatePipe.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreatePipe.java index 1543e339fc86..269978e87bde 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreatePipe.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreatePipe.java @@ -115,9 +115,9 @@ public long ramBytesUsed() { long size = INSTANCE_SIZE; size += AstMemoryEstimationHelper.getEstimatedSizeOfNodeLocation(getLocationInternal()); size += RamUsageEstimator.sizeOf(pipeName); - size += RamUsageEstimator.sizeOfMap(extractorAttributes); + size += RamUsageEstimator.sizeOfMap(sourceAttributes); size += RamUsageEstimator.sizeOfMap(processorAttributes); - size += RamUsageEstimator.sizeOfMap(connectorAttributes); + size += RamUsageEstimator.sizeOfMap(sinkAttributes); return size; } } diff --git a/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift b/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift index 48afb89d3366..e751b9b4e713 100644 --- a/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift +++ b/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift @@ -167,6 +167,36 @@ struct TSCloseOperationReq { 4: optional string preparedStatementName } +// PreparedStatement - PREPARE +// Parses SQL and caches AST in session for later execution +struct TSPrepareReq { + 1: required i64 sessionId + 2: required string sql // SQL with ? placeholders + 3: required string statementName // Name to identify this prepared statement +} + +struct TSPrepareResp { + 1: required common.TSStatus status + 2: optional i32 parameterCount // Number of ? placeholders in SQL +} + +// PreparedStatement - EXECUTE +// Executes a prepared statement with bound parameters +struct TSExecutePreparedReq { + 1: required i64 sessionId + 2: required string statementName // Name of the prepared statement + 3: required list parameters // Serialized parameter values + 4: optional i32 fetchSize + 5: optional i64 timeout +} + +// PreparedStatement - DEALLOCATE +// Releases a prepared statement and its resources +struct TSDeallocatePreparedReq { + 1: required i64 sessionId + 2: required string statementName // Name of the prepared statement to release +} + struct TSFetchResultsReq{ 1: required i64 sessionId 2: required string statement @@ -576,6 +606,13 @@ service IClientRPCService { common.TSStatus closeOperation(1:TSCloseOperationReq req); + // PreparedStatement operations + TSPrepareResp prepareStatement(1:TSPrepareReq req); + + TSExecuteStatementResp executePreparedStatement(1:TSExecutePreparedReq req); + + common.TSStatus deallocatePreparedStatement(1:TSDeallocatePreparedReq req); + TSGetTimeZoneResp getTimeZone(1:i64 sessionId); common.TSStatus setTimeZone(1:TSSetTimeZoneReq req);