diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 0ec0b1e0..5201575b 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -21,9 +21,19 @@ jobs: strategy: fail-fast: false matrix: + clickhouse: [ 25.6, 25.7, 25.8, 25.9, latest ] java: [ 8, 17 ] - scala: [ 2.12, 2.13 ] - spark: [ 3.3, 3.4, 3.5 ] + scala: [ '2.12', '2.13' ] + spark: [ '3.3', '3.4', '3.5', '4.0' ] + exclude: + # Spark 4.0 only supports Scala 2.13 + - spark: '4.0' + scala: '2.12' + # Spark 4.0 requires Java 17+ + - spark: '4.0' + java: 8 + env: + CLICKHOUSE_IMAGE: clickhouse/clickhouse-server:${{ matrix.clickhouse }} steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 @@ -40,34 +50,7 @@ jobs: if: failure() uses: actions/upload-artifact@v4 with: - name: log-java-${{ matrix.java }}-spark-${{ matrix.spark }}-scala-${{ matrix.scala }} - path: | - **/build/unit-tests.log - log/** - - run-tests-with-specific-clickhouse: - runs-on: ubuntu-22.04 - strategy: - fail-fast: false - matrix: - clickhouse: [ 25.3, 25.6, 25.7, latest ] - env: - CLICKHOUSE_IMAGE: clickhouse/clickhouse-server:${{ matrix.clickhouse }} - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-java@v4 - with: - distribution: zulu - java-version: 8 - cache: gradle - - run: >- - ./gradlew clean test --no-daemon --refresh-dependencies - -PmavenCentralMirror=https://maven-central.storage-download.googleapis.com/maven2/ - - name: Upload logs - if: failure() - uses: actions/upload-artifact@v4 - with: - name: log-clickhouse-${{ matrix.clickhouse }} + name: log-ch-${{ matrix.clickhouse }}-java-${{ matrix.java }}-spark-${{ matrix.spark }}-scala-${{ matrix.scala }} path: | **/build/unit-tests.log log/** diff --git a/.github/workflows/check-license.yml b/.github/workflows/check-license.yml index f7fa9b29..c4e82350 100644 --- a/.github/workflows/check-license.yml +++ b/.github/workflows/check-license.yml @@ -29,13 +29,22 @@ jobs: strategy: fail-fast: false matrix: - spark: [ 3.3, 3.4, 3.5 ] + spark: [ "3.3", "3.4", "3.5", "4.0" ] + include: + - spark: "3.3" + java: 8 + - spark: "3.4" + java: 8 + - spark: "3.5" + java: 8 + - spark: "4.0" + java: 17 steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 with: distribution: zulu - java-version: 8 + java-version: ${{ matrix.java }} - run: >- ./gradlew rat --no-daemon -Dspark_binary_version=${{ matrix.spark }} diff --git a/.github/workflows/cloud.yml b/.github/workflows/cloud.yml index db74c52d..1271d60f 100644 --- a/.github/workflows/cloud.yml +++ b/.github/workflows/cloud.yml @@ -34,8 +34,16 @@ jobs: max-parallel: 1 fail-fast: false matrix: - spark: [ 3.3, 3.4, 3.5 ] - scala: [ 2.12, 2.13 ] + spark: [ '3.3', '3.4', '3.5', '4.0' ] + scala: [ '2.12', '2.13' ] + java: [ 8, 17 ] + exclude: + # Spark 4.0 only supports Scala 2.13 + - spark: '4.0' + scala: '2.12' + # Spark 4.0 requires Java 17+ + - spark: '4.0' + java: 8 env: CLICKHOUSE_CLOUD_HOST: ${{ secrets.INTEGRATIONS_TEAM_TESTS_CLOUD_HOST_SMT }} CLICKHOUSE_CLOUD_PASSWORD: ${{ secrets.INTEGRATIONS_TEAM_TESTS_CLOUD_PASSWORD_SMT }} @@ -44,7 +52,7 @@ jobs: - uses: actions/setup-java@v4 with: distribution: zulu - java-version: 8 + java-version: ${{ matrix.java }} cache: gradle - name: Wake up ClickHouse Cloud instance env: @@ -80,7 +88,7 @@ jobs: if: failure() uses: actions/upload-artifact@v4 with: - name: log-clickhouse-cloud-spark-${{ matrix.spark }}-scala-${{ matrix.scala }} + name: log-clickhouse-cloud-spark-${{ matrix.spark }}-scala-${{ matrix.scala }}-java-${{ 
matrix.java }} path: | **/build/unit-tests.log log/** diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 7a809fee..73d1ca13 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -30,13 +30,22 @@ jobs: strategy: fail-fast: false matrix: - spark: [ 3.3, 3.4, 3.5 ] + spark: [ "3.3", "3.4", "3.5", "4.0" ] + include: + - spark: "3.3" + java: 8 + - spark: "3.4" + java: 8 + - spark: "3.5" + java: 8 + - spark: "4.0" + java: 17 steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 with: distribution: zulu - java-version: 8 + java-version: ${{ matrix.java }} cache: gradle - run: >- ./gradlew spotlessCheck --no-daemon --refresh-dependencies diff --git a/.github/workflows/tpcds.yml b/.github/workflows/tpcds.yml index 0fcbe2df..6df01631 100644 --- a/.github/workflows/tpcds.yml +++ b/.github/workflows/tpcds.yml @@ -30,14 +30,22 @@ jobs: strategy: fail-fast: false matrix: - spark: [ 3.3, 3.4, 3.5 ] - scala: [ 2.12, 2.13 ] + spark: [ '3.3', '3.4', '3.5', '4.0' ] + scala: [ '2.12', '2.13' ] + java: [ 8, 17 ] + exclude: + # Spark 4.0 only supports Scala 2.13 + - spark: '4.0' + scala: '2.12' + # Spark 4.0 requires Java 17+ + - spark: '4.0' + java: 8 steps: - uses: actions/checkout@v4 - uses: actions/setup-java@v4 with: distribution: zulu - java-version: 8 + java-version: ${{ matrix.java }} cache: gradle - run: >- ./gradlew clean slowTest --no-daemon --refresh-dependencies @@ -48,7 +56,7 @@ jobs: if: failure() uses: actions/upload-artifact@v4 with: - name: log-tpcds-spark-${{ matrix.spark }}-scala-${{ matrix.scala }} + name: log-tpcds-spark-${{ matrix.spark }}-scala-${{ matrix.scala }}-java-${{ matrix.java }} path: | **/build/unit-tests.log log/** diff --git a/build.gradle b/build.gradle index 171988f6..6f0af275 100644 --- a/build.gradle +++ b/build.gradle @@ -49,8 +49,8 @@ project.ext { spark_prefix = "spark_${spark_binary_version.replace('.', '')}" scala_prefix = "scala_${scala_binary_version.replace('.', '')}" - scala_212_version = project.getProperty("${spark_prefix}_scala_212_version") - scala_213_version = project.getProperty("${spark_prefix}_scala_213_version") + scala_212_version = project.findProperty("${spark_prefix}_scala_212_version") ?: "2.12.18" + scala_213_version = project.findProperty("${spark_prefix}_scala_213_version") ?: "2.13.8" scala_version = project.getProperty("${scala_prefix}_version") antlr_version = project.getProperty("${spark_prefix}_antlr_version") @@ -106,7 +106,11 @@ allprojects { subprojects { apply plugin: "scala" apply plugin: "java-library" - apply plugin: "org.scoverage" + // Disable scoverage when running Metals' bloopInstall to avoid plugin resolution issues + def isBloopInstall = gradle.startParameter.taskNames.any { it.contains('bloopInstall') } + if (!project.hasProperty('disableScoverage') && !isBloopInstall) { + apply plugin: "org.scoverage" + } apply plugin: "com.diffplug.spotless" apply plugin: "com.github.maiflai.scalatest" @@ -168,11 +172,13 @@ subprojects { } } - scoverage { - scoverageVersion = "2.0.11" - reportDir.set(file("${rootProject.buildDir}/reports/scoverage")) - highlighting.set(false) - minimumRate.set(0.0) + if (plugins.hasPlugin('org.scoverage')) { + scoverage { + scoverageVersion = "2.0.11" + reportDir.set(file("${rootProject.buildDir}/reports/scoverage")) + highlighting.set(false) + minimumRate.set(0.0) + } } spotless { diff --git a/clickhouse-core/src/testFixtures/scala/com/clickhouse/spark/base/ClickHouseSingleMixIn.scala 
b/clickhouse-core/src/testFixtures/scala/com/clickhouse/spark/base/ClickHouseSingleMixIn.scala index b5682eb1..36bbb105 100644 --- a/clickhouse-core/src/testFixtures/scala/com/clickhouse/spark/base/ClickHouseSingleMixIn.scala +++ b/clickhouse-core/src/testFixtures/scala/com/clickhouse/spark/base/ClickHouseSingleMixIn.scala @@ -17,13 +17,18 @@ package com.clickhouse.spark.base import com.clickhouse.spark.Utils import com.clickhouse.data.ClickHouseVersion import com.dimafeng.testcontainers.{ForAllTestContainer, JdbcDatabaseContainer, SingleContainer} +import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite +import org.slf4j.LoggerFactory import org.testcontainers.containers.ClickHouseContainer import org.testcontainers.utility.{DockerImageName, MountableFile} import java.nio.file.{Path, Paths} import scala.collection.JavaConverters._ -trait ClickHouseSingleMixIn extends AnyFunSuite with ForAllTestContainer with ClickHouseProvider { +trait ClickHouseSingleMixIn extends AnyFunSuite with BeforeAndAfterAll with ForAllTestContainer + with ClickHouseProvider { + + private val logger = LoggerFactory.getLogger(getClass) // format: off private val CLICKHOUSE_IMAGE: String = Utils.load("CLICKHOUSE_IMAGE", "clickhouse/clickhouse-server:23.8") private val CLICKHOUSE_USER: String = Utils.load("CLICKHOUSE_USER", "default") @@ -34,6 +39,8 @@ trait ClickHouseSingleMixIn extends AnyFunSuite with ForAllTestContainer with Cl private val CLICKHOUSE_TPC_PORT = 9000 // format: on + logger.info(s"Initializing with ClickHouse image: $CLICKHOUSE_IMAGE") + override val clickhouseVersion: ClickHouseVersion = ClickHouseVersion.of(CLICKHOUSE_IMAGE.split(":").last) protected val rootProjectDir: Path = { @@ -80,4 +87,20 @@ trait ClickHouseSingleMixIn extends AnyFunSuite with ForAllTestContainer with Cl override def clickhousePassword: String = CLICKHOUSE_PASSWORD override def clickhouseDatabase: String = CLICKHOUSE_DB override def isSslEnabled: Boolean = false + + override def beforeAll(): Unit = { + val startTime = System.currentTimeMillis() + logger.info(s"Starting ClickHouse container: $CLICKHOUSE_IMAGE") + super.beforeAll() // This starts the container and makes mappedPort available + val duration = System.currentTimeMillis() - startTime + logger.info( + s"ClickHouse container started in ${duration}ms at ${container.host}:${container.mappedPort(CLICKHOUSE_HTTP_PORT)}" + ) + } + + override def afterAll(): Unit = { + logger.info("Stopping ClickHouse container") + super.afterAll() + logger.info("ClickHouse container stopped") + } } diff --git a/gradle.properties b/gradle.properties index dd9fbbe9..664dcc5c 100644 --- a/gradle.properties +++ b/gradle.properties @@ -16,10 +16,10 @@ mavenCentralMirror=https://repo1.maven.org/maven2/ mavenSnapshotsRepo=https://central.sonatype.com/repository/maven-snapshots/ mavenReleasesRepo=https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ -systemProp.scala_binary_version=2.12 +systemProp.scala_binary_version=2.13 systemProp.known_scala_binary_versions=2.12,2.13 -systemProp.spark_binary_version=3.5 -systemProp.known_spark_binary_versions=3.3,3.4,3.5 +systemProp.spark_binary_version=4.0 +systemProp.known_spark_binary_versions=3.3,3.4,3.5,4.0 group=com.clickhouse.spark @@ -29,6 +29,7 @@ clickhouse_client_v2_version=0.9.4 spark_33_version=3.3.4 spark_34_version=3.4.2 spark_35_version=3.5.1 +spark_40_version=4.0.1 spark_33_scala_212_version=2.12.15 spark_34_scala_212_version=2.12.17 @@ -37,22 +38,26 @@ spark_35_scala_212_version=2.12.18 
spark_33_scala_213_version=2.13.8 spark_34_scala_213_version=2.13.8 spark_35_scala_213_version=2.13.8 +spark_40_scala_213_version=2.13.8 spark_33_antlr_version=4.8 spark_34_antlr_version=4.9.3 spark_35_antlr_version=4.9.3 +spark_40_antlr_version=4.13.1 spark_33_jackson_version=2.13.4 spark_34_jackson_version=2.14.2 spark_35_jackson_version=2.15.2 +spark_40_jackson_version=2.17.0 spark_33_slf4j_version=1.7.32 spark_34_slf4j_version=2.0.6 spark_35_slf4j_version=2.0.7 +spark_40_slf4j_version=2.0.7 # Align with Apache Spark, and don't bundle them in release jar. commons_lang3_version=3.12.0 -commons_codec_version=1.16.0 +commons_codec_version=1.17.2 # javax annotations removed in jdk 11 # fix build error with jakarta annotations @@ -61,5 +66,5 @@ jakarta_annotation_api_version=1.3.5 # Test only kyuubi_version=1.9.2 testcontainers_scala_version=0.41.2 -scalatest_version=3.2.16 +scalatest_version=3.2.19 flexmark_version=0.62.2 diff --git a/settings.gradle b/settings.gradle index 3bba3864..44811034 100644 --- a/settings.gradle +++ b/settings.gradle @@ -42,3 +42,8 @@ project(":clickhouse-spark-runtime-${spark_binary_version}_$scala_binary_version include ":clickhouse-spark-it-${spark_binary_version}_$scala_binary_version" project(":clickhouse-spark-it-${spark_binary_version}_$scala_binary_version").projectDir = file("spark-${spark_binary_version}/clickhouse-spark-it") project(":clickhouse-spark-it-${spark_binary_version}_$scala_binary_version").name = "clickhouse-spark-it-${spark_binary_version}_$scala_binary_version" + +// Examples module for running/debugging sample apps in IDE +include ":clickhouse-examples-${spark_binary_version}_$scala_binary_version" +project(":clickhouse-examples-${spark_binary_version}_$scala_binary_version").projectDir = file("spark-${spark_binary_version}/examples") +project(":clickhouse-examples-${spark_binary_version}_$scala_binary_version").name = "clickhouse-examples-${spark_binary_version}_$scala_binary_version" diff --git a/spark-4.0/build.gradle b/spark-4.0/build.gradle new file mode 100644 index 00000000..d1ef9144 --- /dev/null +++ b/spark-4.0/build.gradle @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +project.ext { + spark_version = project.getProperty("spark_40_version") + spark_binary_version = "4.0" +} + + +project(":clickhouse-spark-${spark_binary_version}_$scala_binary_version") { + dependencies { + api project(":clickhouse-core") + + compileOnly "org.apache.spark:spark-sql_$scala_binary_version:$spark_version" + + testImplementation "org.apache.spark:spark-sql_$scala_binary_version:$spark_version" + testImplementation "org.scalatest:scalatest_$scala_binary_version:$scalatest_version" + testRuntimeOnly "com.vladsch.flexmark:flexmark-all:$flexmark_version" + } +} + +project(":clickhouse-spark-runtime-${spark_binary_version}_$scala_binary_version") { + apply plugin: "com.github.johnrengelman.shadow" + + tasks.jar.dependsOn tasks.shadowJar + + dependencies { + compileOnly "org.scala-lang:scala-library:$scala_version" + + implementation(project(":clickhouse-spark-${spark_binary_version}_$scala_binary_version")) { + exclude group: "org.antlr", module: "antlr4-runtime" + exclude group: "org.scala-lang", module: "scala-library" + exclude group: "org.slf4j", module: "slf4j-api" + exclude group: "org.apache.commons", module: "commons-lang3" + exclude group: "com.clickhouse", module: "clickhouse-jdbc" + exclude group: "com.fasterxml.jackson.core" + exclude group: "com.fasterxml.jackson.datatype" + exclude group: "com.fasterxml.jackson.module" + } + } + + shadowJar { + zip64=true + archiveClassifier=null + + mergeServiceFiles() + } + + jar { + archiveClassifier="empty" + manifest { + attributes( + 'Implementation-Title': 'Spark-ClickHouse-Connector', + 'Implementation-Version': "${spark_binary_version}_${scala_binary_version}_${getProjectVersion()}" + ) + } + } +} + +project(":clickhouse-spark-it-${spark_binary_version}_$scala_binary_version") { + dependencies { + implementation "org.scala-lang:scala-library:$scala_version" // for scala plugin detect scala binary version + + testImplementation project(path: ":clickhouse-spark-runtime-${spark_binary_version}_$scala_binary_version", configuration: "shadow") + testImplementation(testFixtures(project(":clickhouse-core"))) { + exclude module: "clickhouse-core" + } + + testImplementation "org.apache.spark:spark-sql_$scala_binary_version:$spark_version" + + testImplementation "org.apache.spark:spark-core_$scala_binary_version:$spark_version:tests" + testImplementation "org.apache.spark:spark-catalyst_$scala_binary_version:$spark_version:tests" + testImplementation "org.apache.spark:spark-sql_$scala_binary_version:$spark_version:tests" + + testImplementation "com.fasterxml.jackson.datatype:jackson-datatype-jsr310:$jackson_version" + + testImplementation("com.clickhouse:clickhouse-jdbc:$clickhouse_jdbc_version:all") { transitive = false } + + testImplementation "org.apache.kyuubi:kyuubi-spark-connector-tpcds_${scala_binary_version}:$kyuubi_version" + } + + test { + classpath += files("${project(':clickhouse-core').projectDir}/src/testFixtures/conf") + } + + slowTest { + classpath += files("${project(':clickhouse-core').projectDir}/src/testFixtures/conf") + } + + cloudTest { + classpath += files("${project(':clickhouse-core').projectDir}/src/testFixtures/conf") + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/resources/log4j2.xml b/spark-4.0/clickhouse-spark-it/src/test/resources/log4j2.xml new file mode 100644 index 00000000..3e2579f1 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/resources/log4j2.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git 
a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/SparkTest.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/SparkTest.scala new file mode 100644 index 00000000..dc914722 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/SparkTest.scala @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse + +import org.apache.spark.SparkConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.{DataFrame, QueryTest} +import com.clickhouse.spark.ClickHouseCommandRunner + +import java.sql.{Date, Timestamp} +import java.time.Instant + +trait SparkTest extends QueryTest with SharedSparkSession { + + def cmdRunnerOptions: Map[String, String] + + /** + * @param text format yyyy-[m]m-[d]d + * @return A SQL Date + */ + def date(text: String): Date = Date.valueOf(text) + + /** + * @param text format 2007-12-03T10:15:30.00Z + * @return A SQL Timestamp + */ + def timestamp(text: String): Timestamp = Timestamp.from(Instant.parse(text)) + + override protected def sparkConf: SparkConf = super.sparkConf + .setMaster("local[2]") + .setAppName("spark-ut") + .set("spark.ui.enabled", "false") + .set("spark.driver.host", "localhost") + .set("spark.driver.memory", "500M") + .set("spark.sql.catalogImplementation", "in-memory") + .set("spark.sql.codegen.wholeStage", "false") + .set("spark.sql.shuffle.partitions", "2") + + def runClickHouseSQL(sql: String, options: Map[String, String] = cmdRunnerOptions): DataFrame = + spark.executeCommand(classOf[ClickHouseCommandRunner].getName, sql, options) + + def autoCleanupTable( + database: String, + table: String, + cleanup: Boolean = true + )(block: (String, String) => Unit): Unit = + try { + spark.sql(s"CREATE DATABASE IF NOT EXISTS `$database`") + block(database, table) + } finally if (cleanup) { + spark.sql(s"DROP TABLE IF EXISTS `$database`.`$table`") + spark.sql(s"DROP DATABASE IF EXISTS `$database` CASCADE") + } + + def withClickHouseSingleIdTable( + database: String, + table: String, + cleanup: Boolean = true + )(block: (String, String) => Unit): Unit = autoCleanupTable(database, table, cleanup) { (database, table) => + spark.sql( + s"""CREATE TABLE IF NOT EXISTS `$database`.`$table` ( + | id Long NOT NULL + |) USING ClickHouse + |TBLPROPERTIES ( + | engine = 'MergeTree()', + | order_by = 'id', + | settings.index_granularity = 8192 + |) + |""".stripMargin + ) + block(database, table) + } + + // for debugging webui + protected def infiniteLoop(): Unit = while (true) { + Thread.sleep(1000) + spark.catalog.listTables() + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/TPCDSTestUtils.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/TPCDSTestUtils.scala new file mode 100644 index 00000000..5f2925fa --- /dev/null +++ 
b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/TPCDSTestUtils.scala @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse + +object TPCDSTestUtils { + val tablePrimaryKeys: Map[String, Seq[String]] = Map( + "call_center" -> Array("cc_call_center_sk"), + "catalog_page" -> Array("cp_catalog_page_sk"), + "catalog_returns" -> Array("cr_item_sk", "cr_order_number"), + "catalog_sales" -> Array("cs_item_sk", "cs_order_number"), + "customer" -> Array("c_customer_sk"), + "customer_address" -> Array("ca_address_sk"), + "customer_demographics" -> Array("cd_demo_sk"), + "date_dim" -> Array("d_date_sk"), + "household_demographics" -> Array("hd_demo_sk"), + "income_band" -> Array("ib_income_band_sk"), + "inventory" -> Array("inv_date_sk", "inv_item_sk", "inv_warehouse_sk"), + "item" -> Array("i_item_sk"), + "promotion" -> Array("p_promo_sk"), + "reason" -> Array("r_reason_sk"), + "ship_mode" -> Array("sm_ship_mode_sk"), + "store" -> Array("s_store_sk"), + "store_returns" -> Array("sr_item_sk", "sr_ticket_number"), + "store_sales" -> Array("ss_item_sk", "ss_ticket_number"), + "time_dim" -> Array("t_time_sk"), + "warehouse" -> Array("w_warehouse_sk"), + "web_page" -> Array("wp_web_page_sk"), + "web_returns" -> Array("wr_item_sk", "wr_order_number"), + "web_sales" -> Array("ws_item_sk", "ws_order_number"), + "web_site" -> Array("web_site_sk") + ) +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/TestUtils.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/TestUtils.scala new file mode 100644 index 00000000..8107a884 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/TestUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse + +import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} +import com.fasterxml.jackson.module.scala.ClassTagExtensions + +object TestUtils { + + @transient lazy val om: ObjectMapper with ClassTagExtensions = { + val _om = new ObjectMapper() with ClassTagExtensions + _om.findAndRegisterModules() + _om.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + _om + } + + def toJson(value: Any): String = om.writeValueAsString(value) +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/BaseClusterWriteSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/BaseClusterWriteSuite.scala new file mode 100644 index 00000000..d2380668 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/BaseClusterWriteSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse.cluster + +import org.apache.spark.SparkConf +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +abstract class BaseClusterWriteSuite extends SparkClickHouseClusterTest { + + test("clickhouse write cluster") { + withSimpleDistTable("single_replica", "db_w", "t_dist", true) { (_, db, tbl_dist, tbl_local) => + val tblSchema = spark.table(s"$db.$tbl_dist").schema + assert(tblSchema == StructType( + StructField("create_time", DataTypes.TimestampType, nullable = false) :: + StructField("y", DataTypes.IntegerType, nullable = false) :: + StructField("m", DataTypes.IntegerType, nullable = false) :: + StructField("id", DataTypes.LongType, nullable = false) :: + StructField("value", DataTypes.StringType, nullable = true) :: Nil + )) + + checkAnswer( + spark + .table(s"$db.$tbl_dist") + .select("create_time", "y", "m", "id", "value"), + Seq( + Row(timestamp("2021-01-01T10:10:10Z"), 2021, 1, 1L, "1"), + Row(timestamp("2022-02-02T10:10:10Z"), 2022, 2, 2L, "2"), + Row(timestamp("2023-03-03T10:10:10Z"), 2023, 3, 3L, "3"), + Row(timestamp("2024-04-04T10:10:10Z"), 2024, 4, 4L, "4") + ) + ) + + checkAnswer( + spark.table(s"clickhouse_s1r1.$db.$tbl_local"), + Row(timestamp("2024-04-04T10:10:10Z"), 2024, 4, 4L, "4") :: Nil + ) + checkAnswer( + spark.table(s"clickhouse_s1r2.$db.$tbl_local"), + Row(timestamp("2021-01-01T10:10:10Z"), 2021, 1, 1L, "1") :: Nil + ) + checkAnswer( + spark.table(s"clickhouse_s2r1.$db.$tbl_local"), + Row(timestamp("2022-02-02T10:10:10Z"), 2022, 2, 2L, "2") :: Nil + ) + checkAnswer( + spark.table(s"clickhouse_s2r2.$db.$tbl_local"), + Row(timestamp("2023-03-03T10:10:10Z"), 2023, 3, 3L, "3") :: Nil + ) + } + } +} + +class ClusterNodesWriteSuite extends BaseClusterWriteSuite { + + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.clickhouse.write.write.repartitionNum", "0") + .set("spark.clickhouse.write.distributed.useClusterNodes", "true") + .set("spark.clickhouse.write.distributed.convertLocal", 
"false") +} + +class ConvertDistToLocalWriteSuite extends BaseClusterWriteSuite { + + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.clickhouse.write.write.repartitionNum", "0") + .set("spark.clickhouse.write.distributed.useClusterNodes", "true") + .set("spark.clickhouse.write.distributed.convertLocal", "true") +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClickHouseClusterHashUDFSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClickHouseClusterHashUDFSuite.scala new file mode 100644 index 00000000..d711ff4b --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClickHouseClusterHashUDFSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse.cluster + +import org.apache.spark.sql.clickhouse.TestUtils.om +import com.clickhouse.spark.func._ +import java.lang.{Long => JLong} + +class ClickHouseClusterHashUDFSuite extends SparkClickHouseClusterTest { + // only for query function names + val dummyRegistry: CompositeFunctionRegistry = { + val dynamicFunctionRegistry = new DynamicFunctionRegistry + val xxHash64ShardFunc = new ClickHouseXxHash64Shard(Seq.empty) + dynamicFunctionRegistry.register("ck_xx_hash64_shard", xxHash64ShardFunc) // for compatible + dynamicFunctionRegistry.register("clickhouse_shard_xxHash64", xxHash64ShardFunc) + new CompositeFunctionRegistry(Array(StaticFunctionRegistry, dynamicFunctionRegistry)) + } + + def runTest(sparkFuncName: String, ckFuncName: String, stringVal: String): Unit = { + val sparkResult = spark.sql( + s"SELECT $sparkFuncName($stringVal) AS hash_value" + ).collect + assert(sparkResult.length == 1) + val sparkHashVal = sparkResult.head.getAs[Long]("hash_value") + + val clickhouseResultJsonStr = runClickHouseSQL( + s"SELECT $ckFuncName($stringVal) AS hash_value " + ).head.getString(0) + val clickhouseResultJson = om.readTree(clickhouseResultJsonStr) + val clickhouseHashVal = JLong.parseUnsignedLong(clickhouseResultJson.get("hash_value").asText) + assert( + sparkHashVal == clickhouseHashVal, + s"ck_function: $ckFuncName, spark_function: $sparkFuncName, args: ($stringVal)" + ) + } + + Seq( + "clickhouse_xxHash64", + "clickhouse_murmurHash3_64", + "clickhouse_murmurHash3_32", + "clickhouse_murmurHash2_64", + "clickhouse_murmurHash2_32", + "clickhouse_cityHash64" + ).foreach { sparkFuncName => + val ckFuncName = dummyRegistry.sparkToClickHouseFunc(sparkFuncName) + test(s"UDF $sparkFuncName") { + Seq( + "spark-clickhouse-connector", + "Apache Spark", + "ClickHouse", + "Yandex", + "热爱", + "在传统的行式数据库系统中,数据按如下顺序存储:", + "🇨🇳" + ).map("'" + _ + "'").foreach { stringVal => + runTest(sparkFuncName, ckFuncName, stringVal) + } + } + } + + Seq( + "clickhouse_murmurHash3_64", + "clickhouse_murmurHash3_32", + "clickhouse_murmurHash2_64", + "clickhouse_murmurHash2_32", + "clickhouse_cityHash64" + ).foreach { 
sparkFuncName => + val ckFuncName = dummyRegistry.sparkToClickHouseFunc(sparkFuncName) + test(s"UDF $sparkFuncName multiple args") { + Seq( + "spark-clickhouse-connector", + "Apache Spark", + "ClickHouse", + "Yandex", + "热爱", + "在传统的行式数据库系统中,数据按如下顺序存储:", + "🇨🇳" + ).map("'" + _ + "'").combinations(5).foreach { seq => + val stringVal = seq.mkString(", ") + runTest(sparkFuncName, ckFuncName, stringVal) + } + } + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClickHouseClusterReadSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClickHouseClusterReadSuite.scala new file mode 100644 index 00000000..44fe1ff2 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClickHouseClusterReadSuite.scala @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse.cluster + +import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.READ_DISTRIBUTED_CONVERT_LOCAL +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec + +class ClickHouseClusterReadSuite extends SparkClickHouseClusterTest { + + test("clickhouse metadata column - distributed table") { + withSimpleDistTable("single_replica", "db_w", "t_dist", true) { (_, db, tbl_dist, _) => + assert(READ_DISTRIBUTED_CONVERT_LOCAL.defaultValueString == "true") + + withSQLConf(READ_DISTRIBUTED_CONVERT_LOCAL.key -> "true") { + // `_shard_num` is dedicated for Distributed table + val cause = intercept[AnalysisException] { + spark.sql(s"SELECT y, _shard_num FROM $db.$tbl_dist") + } + assert(cause.message.contains("`_shard_num` cannot be resolved")) + } + + withSQLConf(READ_DISTRIBUTED_CONVERT_LOCAL.key -> "false") { + checkAnswer( + spark.sql(s"SELECT y, _shard_num FROM $db.$tbl_dist"), + Seq( + Row(2021, 2), + Row(2022, 3), + Row(2023, 4), + Row(2024, 1) + ) + ) + } + } + } + + test("push down aggregation - distributed table") { + withSimpleDistTable("single_replica", "db_agg_col", "t_dist", true) { (_, db, tbl_dist, _) => + checkAnswer( + spark.sql(s"SELECT COUNT(id) FROM $db.$tbl_dist"), + Seq(Row(4)) + ) + + checkAnswer( + spark.sql(s"SELECT MIN(id) FROM $db.$tbl_dist"), + Seq(Row(1)) + ) + + checkAnswer( + spark.sql(s"SELECT MAX(id) FROM $db.$tbl_dist"), + Seq(Row(4)) + ) + + checkAnswer( + spark.sql(s"SELECT m, COUNT(DISTINCT id) FROM $db.$tbl_dist GROUP BY m"), + Seq( + Row(1, 1), + Row(2, 1), + Row(3, 1), + Row(4, 1) + ) + ) + + checkAnswer( + spark.sql(s"SELECT m, SUM(DISTINCT id) FROM $db.$tbl_dist GROUP BY m"), + Seq( + Row(1, 1), + Row(2, 2), + Row(3, 3), + Row(4, 4) + ) + ) + } + } + + test("runtime filter - distributed table") { + withSimpleDistTable("single_replica", "runtime_db", "runtime_tbl", true) { (_, db, tbl_dist, _) => + spark.sql("set spark.clickhouse.read.runtimeFilter.enabled=false") + checkAnswer( + 
spark.sql(s"SELECT id FROM $db.$tbl_dist " + + s"WHERE id IN (" + + s" SELECT id FROM $db.$tbl_dist " + + s" WHERE DATE_FORMAT(create_time, 'yyyy-MM-dd') between '2021-01-01' and '2022-01-01'" + + s")"), + Row(1) + ) + + spark.sql("set spark.clickhouse.read.runtimeFilter.enabled=true") + val df = spark.sql(s"SELECT id FROM $db.$tbl_dist " + + s"WHERE id IN (" + + s" SELECT id FROM $db.$tbl_dist " + + s" WHERE DATE_FORMAT(create_time, 'yyyy-MM-dd') between '2021-01-01' and '2022-01-01'" + + s")") + checkAnswer(df, Row(1)) + val runtimeFilterExists = df.queryExecution.sparkPlan.exists { + case BatchScanExec(_, _, runtimeFilters, _, table, _) + if table.name() == TableIdentifier(tbl_dist, Some(db)).quotedString + && runtimeFilters.nonEmpty => true + case _ => false + } + assert(runtimeFilterExists) + } + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClusterDeleteSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClusterDeleteSuite.scala new file mode 100644 index 00000000..a5d7d0e4 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClusterDeleteSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse.cluster + +class ClusterDeleteSuite extends SparkClickHouseClusterTest { + + test("truncate distribute table") { + withSimpleDistTable("single_replica", "db_truncate", "tbl_truncate", true) { (_, db, tbl_dist, _) => + assert(spark.table(s"$db.$tbl_dist").count() === 4) + spark.sql(s"TRUNCATE TABLE $db.$tbl_dist") + assert(spark.table(s"$db.$tbl_dist").count() === 0) + } + } + + test("delete from distribute table") { + withSimpleDistTable("single_replica", "db_delete", "tbl_delete", true) { (_, db, tbl_dist, _) => + assert(spark.table(s"$db.$tbl_dist").count() === 4) + spark.sql(s"DELETE FROM $db.$tbl_dist WHERE m = 1") + assert(spark.table(s"$db.$tbl_dist").count() === 3) + } + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClusterPartitionManagementSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClusterPartitionManagementSuite.scala new file mode 100644 index 00000000..63da1075 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClusterPartitionManagementSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse.cluster + +import org.apache.spark.sql.Row + +class ClusterPartitionManagementSuite extends SparkClickHouseClusterTest { + + test("distribute table partition") { + withSimpleDistTable("single_replica", "db_part", "tbl_part", true) { (_, db, tbl_dist, _) => + checkAnswer( + spark.sql(s"SHOW PARTITIONS $db.$tbl_dist"), + Seq(Row("m=1"), Row("m=2"), Row("m=3"), Row("m=4")) + ) + checkAnswer( + spark.sql(s"SHOW PARTITIONS $db.$tbl_dist PARTITION(m = 2)"), + Seq(Row("m=2")) + ) + spark.sql(s"ALTER TABLE $db.$tbl_dist DROP PARTITION(m = 2)") + checkAnswer( + spark.sql(s"SHOW PARTITIONS $db.$tbl_dist"), + Seq(Row("m=1"), Row("m=3"), Row("m=4")) + ) + } + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClusterShardByRandSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClusterShardByRandSuite.scala new file mode 100644 index 00000000..bade6e91 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClusterShardByRandSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse.cluster + +import org.apache.spark.sql.Row + +class ClusterShardByRandSuite extends SparkClickHouseClusterTest { + + test("shard by rand()") { + val cluster = "single_replica" + val db = "db_rand_shard" + val tbl_dist = "tbl_rand_shard" + val tbl_local = s"${tbl_dist}_local" + + try { + runClickHouseSQL(s"CREATE DATABASE IF NOT EXISTS $db ON CLUSTER $cluster") + + spark.sql( + s"""CREATE TABLE $db.$tbl_local ( + | create_time TIMESTAMP NOT NULL, + | value STRING + |) USING ClickHouse + |TBLPROPERTIES ( + | cluster = '$cluster', + | engine = 'MergeTree()', + | order_by = 'create_time' + |) + |""".stripMargin + ) + + runClickHouseSQL( + s"""CREATE TABLE $db.$tbl_dist ON CLUSTER $cluster + |AS $db.$tbl_local + |ENGINE = Distributed($cluster, '$db', '$tbl_local', rand()) + |""".stripMargin + ) + spark.sql( + s"""INSERT INTO `$db`.`$tbl_dist` + |VALUES + | (timestamp'2021-01-01 10:10:10', '1'), + | (timestamp'2022-02-02 10:10:10', '2'), + | (timestamp'2023-03-03 10:10:10', '3'), + | (timestamp'2024-04-04 10:10:10', '4') AS tab(create_time, value) + |""".stripMargin + ) + checkAnswer( + spark.table(s"$db.$tbl_dist").select("value").orderBy("create_time"), + Seq(Row("1"), Row("2"), Row("3"), Row("4")) + ) + } finally { + runClickHouseSQL(s"DROP TABLE IF EXISTS $db.$tbl_dist ON CLUSTER $cluster") + runClickHouseSQL(s"DROP TABLE IF EXISTS $db.$tbl_local ON CLUSTER $cluster") + runClickHouseSQL(s"DROP DATABASE IF EXISTS $db ON CLUSTER $cluster") + } + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClusterTableManagementSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClusterTableManagementSuite.scala new file mode 100644 index 00000000..7096160d --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClusterTableManagementSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse.cluster + +class ClusterTableManagementSuite extends SparkClickHouseClusterTest { + + test("create or replace distribute table") { + autoCleanupDistTable("single_replica", "db_cor", "tbl_cor_dist") { (cluster, db, _, tbl_local) => + def createLocalTable(): Unit = spark.sql( + s"""CREATE TABLE $db.$tbl_local ( + | id Long NOT NULL + |) USING ClickHouse + |TBLPROPERTIES ( + | cluster = '$cluster', + | engine = 'MergeTree()', + | order_by = 'id', + | settings.index_granularity = 8192 + |) + |""".stripMargin + ) + + def createOrReplaceLocalTable(): Unit = spark.sql( + s"""CREATE OR REPLACE TABLE `$db`.`$tbl_local` ( + | id Long NOT NULL + |) USING ClickHouse + |TBLPROPERTIES ( + | engine = 'MergeTree()', + | order_by = 'id', + | settings.index_granularity = 8192 + |) + |""".stripMargin + ) + createLocalTable() + createOrReplaceLocalTable() + createOrReplaceLocalTable() + } + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/SparkClickHouseClusterTest.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/SparkClickHouseClusterTest.scala new file mode 100644 index 00000000..bc91abe4 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/SparkClickHouseClusterTest.scala @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse.cluster + +import com.clickhouse.spark.base.ClickHouseClusterMixIn +import org.apache.spark.SparkConf +import org.apache.spark.sql.clickhouse.SparkTest +import org.apache.spark.sql.functions.{month, year} + +trait SparkClickHouseClusterTest extends SparkTest with ClickHouseClusterMixIn { + + import testImplicits._ + + override protected def sparkConf: SparkConf = super.sparkConf + .setMaster("local[4]") + .setAppName("spark-clickhouse-cluster-ut") + .set("spark.sql.shuffle.partitions", "4") + // catalog + .set("spark.sql.defaultCatalog", "clickhouse_s1r1") + .set("spark.sql.catalog.clickhouse_s1r1", "com.clickhouse.spark.ClickHouseCatalog") + .set("spark.sql.catalog.clickhouse_s1r1.host", clickhouse_s1r1_host) + .set("spark.sql.catalog.clickhouse_s1r1.http_port", clickhouse_s1r1_http_port.toString) + .set("spark.sql.catalog.clickhouse_s1r1.protocol", "http") + .set("spark.sql.catalog.clickhouse_s1r1.user", "default") + .set("spark.sql.catalog.clickhouse_s1r1.password", "") + .set("spark.sql.catalog.clickhouse_s1r1.database", "default") + .set("spark.sql.catalog.clickhouse_s1r1.option.custom_http_params", "async_insert=1,wait_for_async_insert=1") + .set("spark.sql.catalog.clickhouse_s1r2", "com.clickhouse.spark.ClickHouseCatalog") + .set("spark.sql.catalog.clickhouse_s1r2.host", clickhouse_s1r2_host) + .set("spark.sql.catalog.clickhouse_s1r2.http_port", clickhouse_s1r2_http_port.toString) + .set("spark.sql.catalog.clickhouse_s1r2.protocol", "http") + .set("spark.sql.catalog.clickhouse_s1r2.user", "default") + .set("spark.sql.catalog.clickhouse_s1r2.password", "") + .set("spark.sql.catalog.clickhouse_s1r2.database", "default") + .set("spark.sql.catalog.clickhouse_s1r2.option.custom_http_params", "async_insert=1,wait_for_async_insert=1") + .set("spark.sql.catalog.clickhouse_s2r1", "com.clickhouse.spark.ClickHouseCatalog") + .set("spark.sql.catalog.clickhouse_s2r1.host", clickhouse_s2r1_host) + .set("spark.sql.catalog.clickhouse_s2r1.http_port", clickhouse_s2r1_http_port.toString) + .set("spark.sql.catalog.clickhouse_s2r1.protocol", "http") + .set("spark.sql.catalog.clickhouse_s2r1.user", "default") + .set("spark.sql.catalog.clickhouse_s2r1.password", "") + .set("spark.sql.catalog.clickhouse_s2r1.database", "default") + .set("spark.sql.catalog.clickhouse_s2r1.option.custom_http_params", "async_insert=1,wait_for_async_insert=1") + .set("spark.sql.catalog.clickhouse_s2r2", "com.clickhouse.spark.ClickHouseCatalog") + .set("spark.sql.catalog.clickhouse_s2r2.host", clickhouse_s2r2_host) + .set("spark.sql.catalog.clickhouse_s2r2.http_port", clickhouse_s2r2_http_port.toString) + .set("spark.sql.catalog.clickhouse_s2r2.protocol", "http") + .set("spark.sql.catalog.clickhouse_s2r2.user", "default") + .set("spark.sql.catalog.clickhouse_s2r2.password", "") + .set("spark.sql.catalog.clickhouse_s2r2.database", "default") + .set("spark.sql.catalog.clickhouse_s2r2.option.custom_http_params", "async_insert=1,wait_for_async_insert=1") + // extended configurations + .set("spark.clickhouse.write.batchSize", "2") + .set("spark.clickhouse.write.maxRetry", "2") + .set("spark.clickhouse.write.retryInterval", "1") + .set("spark.clickhouse.write.retryableErrorCodes", "241") + .set("spark.clickhouse.write.write.repartitionNum", "0") + .set("spark.clickhouse.write.distributed.useClusterNodes", "true") + .set("spark.clickhouse.read.distributed.useClusterNodes", "false") + .set("spark.clickhouse.write.distributed.convertLocal", "false") + 
.set("spark.clickhouse.read.distributed.convertLocal", "true") + .set("spark.clickhouse.read.format", "binary") + .set("spark.clickhouse.write.format", "arrow") + + override def cmdRunnerOptions: Map[String, String] = Map( + "host" -> clickhouse_s1r1_host, + "http_port" -> clickhouse_s1r1_http_port.toString, + "protocol" -> "http", + "user" -> "default", + "password" -> "", + "database" -> "default" + ) + + def autoCleanupDistTable( + cluster: String, + db: String, + tbl_dist: String + )(f: (String, String, String, String) => Unit): Unit = { + val tbl_local = s"${tbl_dist}_local" + try { + runClickHouseSQL(s"CREATE DATABASE IF NOT EXISTS $db ON CLUSTER $cluster") + f(cluster, db, tbl_dist, tbl_local) + } finally { + runClickHouseSQL(s"DROP TABLE IF EXISTS $db.$tbl_dist ON CLUSTER $cluster") + runClickHouseSQL(s"DROP TABLE IF EXISTS $db.$tbl_local ON CLUSTER $cluster") + runClickHouseSQL(s"DROP DATABASE IF EXISTS $db ON CLUSTER $cluster") + } + } + + def withSimpleDistTable( + cluster: String, + db: String, + tbl_dist: String, + writeData: Boolean = false + )(f: (String, String, String, String) => Unit): Unit = + autoCleanupDistTable(cluster, db, tbl_dist) { (cluster, db, tbl_dist, tbl_local) => + spark.sql( + s"""CREATE TABLE $db.$tbl_dist ( + | create_time TIMESTAMP NOT NULL, + | y INT NOT NULL COMMENT 'shard key', + | m INT NOT NULL COMMENT 'part key', + | id BIGINT NOT NULL COMMENT 'sort key', + | value STRING + |) USING ClickHouse + |PARTITIONED BY (m) + |TBLPROPERTIES ( + | cluster = '$cluster', + | engine = 'Distributed', + | shard_by = 'y', + | local.engine = 'MergeTree()', + | local.database = '$db', + | local.table = '$tbl_local', + | local.order_by = 'id', + | local.settings.index_granularity = 8192 + |) + |""".stripMargin + ) + Thread.sleep(3000) + if (writeData) { + val tblSchema = spark.table(s"$db.$tbl_dist").schema + val dataDF = spark.createDataFrame(Seq( + (timestamp("2021-01-01T10:10:10Z"), 1L, "1"), + (timestamp("2022-02-02T10:10:10Z"), 2L, "2"), + (timestamp("2023-03-03T10:10:10Z"), 3L, "3"), + (timestamp("2024-04-04T10:10:10Z"), 4L, "4") + )).toDF("create_time", "id", "value") + .withColumn("y", year($"create_time")) + .withColumn("m", month($"create_time")) + .select($"create_time", $"y", $"m", $"id", $"value") + + spark.createDataFrame(dataDF.rdd, tblSchema) + .writeTo(s"$db.$tbl_dist") + .append + } + f(cluster, db, tbl_dist, tbl_local) + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/TPCDSClusterSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/TPCDSClusterSuite.scala new file mode 100644 index 00000000..a50506d1 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/TPCDSClusterSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse.cluster + +import org.apache.spark.SparkConf +import org.apache.spark.sql.clickhouse.TPCDSTestUtils +import org.scalatest.tags.Slow + +@Slow +class TPCDSClusterSuite extends SparkClickHouseClusterTest { + + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.catalog.tpcds", "org.apache.kyuubi.spark.connector.tpcds.TPCDSCatalog") + .set("spark.sql.catalog.clickhouse_s1r1.protocol", "http") + .set("spark.sql.catalog.clickhouse_s1r2.protocol", "http") + .set("spark.sql.catalog.clickhouse_s2r1.protocol", "http") + .set("spark.sql.catalog.clickhouse_s2r2.protocol", "http") + .set("spark.clickhouse.read.compression.codec", "lz4") + .set("spark.clickhouse.write.batchSize", "100000") + .set("spark.clickhouse.write.compression.codec", "lz4") + .set("spark.clickhouse.write.distributed.convertLocal", "true") + .set("spark.clickhouse.write.format", "json") + + test("Cluster: TPC-DS sf1 write and count(*)") { + withDatabase("tpcds_sf1_cluster") { + spark.sql("CREATE DATABASE tpcds_sf1_cluster WITH DBPROPERTIES (cluster = 'single_replica')") + + TPCDSTestUtils.tablePrimaryKeys.foreach { case (table, primaryKeys) => + println(s"before table ${table} ${primaryKeys}") + val start: Long = System.currentTimeMillis() + spark.sql( + s""" + |CREATE TABLE tpcds_sf1_cluster.$table + |USING clickhouse + |TBLPROPERTIES ( + | cluster = 'single_replica', + | engine = 'distributed', + | 'local.order_by' = '${primaryKeys.mkString(",")}', + | 'local.settings.allow_nullable_key' = 1 + |) + |SELECT * FROM tpcds.sf1.$table; + |""".stripMargin + ) + println(s"time took table ${table} ${System.currentTimeMillis() - start}") + } + + TPCDSTestUtils.tablePrimaryKeys.keys.foreach { table => + println(s"table ${table}") + assert(spark.table(s"tpcds.sf1.$table").count === spark.table(s"tpcds_sf1_cluster.$table").count) + } + } + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseArrowWriterSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseArrowWriterSuite.scala new file mode 100644 index 00000000..721ba948 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseArrowWriterSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse.single + +import com.clickhouse.spark.base.ClickHouseSingleMixIn +import org.apache.spark.SparkConf + +class ClickHouseSingleArrowWriterSuite extends ClickHouseArrowWriterSuite with ClickHouseSingleMixIn + +abstract class ClickHouseArrowWriterSuite extends ClickHouseWriterTestBase { + + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.clickhouse.write.format", "arrow") + .set("spark.clickhouse.read.format", "json") + +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseBinaryReaderSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseBinaryReaderSuite.scala new file mode 100644 index 00000000..decfe0af --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseBinaryReaderSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse.single + +import com.clickhouse.spark.base.{ClickHouseCloudMixIn, ClickHouseSingleMixIn} +import org.apache.spark.SparkConf +import org.scalatest.tags.Cloud + +@Cloud +class ClickHouseCloudBinaryReaderSuite extends ClickHouseBinaryReaderSuite with ClickHouseCloudMixIn + +class ClickHouseSingleBinaryReaderSuite extends ClickHouseBinaryReaderSuite with ClickHouseSingleMixIn + +/** + * Test suite for ClickHouse Binary Reader. + * Uses binary format for reading data from ClickHouse. + * All test cases are inherited from ClickHouseReaderTestBase. + */ +abstract class ClickHouseBinaryReaderSuite extends ClickHouseReaderTestBase { + + // Override to use binary format for reading + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.clickhouse.read.format", "binary") + .set("spark.clickhouse.write.format", "arrow") + + // All tests are inherited from ClickHouseReaderTestBase + // Additional binary-specific tests can be added here if needed +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseDataTypeSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseDataTypeSuite.scala new file mode 100644 index 00000000..62f6a15d --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseDataTypeSuite.scala @@ -0,0 +1,197 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse.single + +import com.clickhouse.spark.base.{ClickHouseCloudMixIn, ClickHouseSingleMixIn} +import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.USE_NULLABLE_QUERY_SCHEMA +import org.apache.spark.sql.clickhouse.SparkUtils +import org.apache.spark.sql.types.DataTypes.{createArrayType, createMapType} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} +import org.scalatest.tags.Cloud + +import java.math.MathContext + +@Cloud +class ClickHouseCloudDataTypeSuite extends ClickHouseDataTypeSuite with ClickHouseCloudMixIn + +class ClickHouseSingleDataTypeSuite extends ClickHouseDataTypeSuite with ClickHouseSingleMixIn + +abstract class ClickHouseDataTypeSuite extends SparkClickHouseSingleTest { + + val SPARK_43390_ENABLED: Boolean = sys.env.contains("SPARK_43390_ENABLED") || { + SparkUtils.MAJOR_MINOR_VERSION match { + case (major, _) if major > 3 => true + case (3, minor) if minor > 4 => true + case _ => false + } + } + + test("write supported data types") { + val schema = StructType( + StructField("id", LongType, false) :: + StructField("col_string", StringType, false) :: + StructField("col_date", DateType, false) :: + StructField("col_array_string", createArrayType(StringType, false), false) :: + StructField("col_map_string_string", createMapType(StringType, StringType, false), false) :: + Nil + ) + val db = "t_w_s_db" + val tbl = "t_w_s_tbl" + withTable(db, tbl, schema) { + val tblSchema = spark.table(s"$db.$tbl").schema + val respectNullable = SPARK_43390_ENABLED && !spark.conf.get(USE_NULLABLE_QUERY_SCHEMA) + if (respectNullable) { + // TODO nested field does not respect nullable + // assert(StructType(schema) === tblSchema) + } else { + val nullableFields = + schema.fields.map(structField => structField.copy(dataType = structField.dataType.asNullable)) + assert(StructType(nullableFields) === tblSchema) + } + + val dataDF = spark.createDataFrame(Seq( + (1L, "a", date("1996-06-06"), Seq("a", "b", "c"), Map("a" -> "x")), + (2L, "A", date("2022-04-12"), Seq("A", "B", "C"), Map("A" -> "X")) + )).toDF("id", "col_string", "col_date", "col_array_string", "col_map_string_string") + + spark.createDataFrame(dataDF.rdd, tblSchema) + .writeTo(s"$db.$tbl") + .append + + checkAnswer( + spark.table(s"$db.$tbl").sort("id"), + Row(1L, "a", date("1996-06-06"), Seq("a", "b", "c"), Map("a" -> "x")) :: + Row(2L, "A", date("2022-04-12"), Seq("A", "B", "C"), Map("A" -> "X")) :: Nil + ) + } + } + + // "allow_experimental_bigint_types" setting is removed since v21.7.1.7020-testing + // https://github.com/ClickHouse/ClickHouse/pull/24812 + val BIGINT_TYPES: Seq[String] = Seq("Int128", "UInt128", "Int256", "UInt256") + + // TODO - Supply more test cases + // 1. data type alias + // 2. negative cases + // 3. 
unsupported integer types + Seq( + ("Int8", -128.toByte, 127.toByte), + ("UInt8", 0.toShort, 255.toShort), + ("Int16", -32768.toShort, 32767.toShort), + ("UInt16", 0, 65535), + ("Int32", -2147483648, 2147483647), + ("UInt32", 0L, 4294967295L), + ("Int64", -9223372036854775808L, 9223372036854775807L), + // Only overlapping value range of both the ClickHouse type and the Spark type is supported + ("UInt64", 0L, 4294967295L), + ("Int128", BigDecimal("-" + "9" * 38), BigDecimal("9" * 38)), + ("UInt128", BigDecimal(0), BigDecimal("9" * 38)), + ("Int256", BigDecimal("-" + "9" * 38), BigDecimal("9" * 38)), + ("UInt256", BigDecimal(0), BigDecimal("9" * 38)) + ).foreach { case (dataType, lower, upper) => + test(s"DateType - $dataType") { + if (BIGINT_TYPES.contains(dataType)) { + assume(clickhouseVersion.isNewerOrEqualTo("21.7.1.7020")) + } + testDataType(dataType) { (db, tbl) => + runClickHouseSQL( + s"""INSERT INTO $db.$tbl VALUES + |(1, $lower), + |(2, $upper) + |""".stripMargin + ) + } { df => + checkAnswer( + df, + Row(1, lower) :: Row(2, upper) :: Nil + ) + checkAnswer( + df.filter("value > 1"), + Row(2, upper) :: Nil + ) + } + } + } + + test("DataType - DateTime") { + testDataType("DateTime") { (db, tbl) => + runClickHouseSQL( + s"""INSERT INTO $db.$tbl VALUES + |(1, '2021-01-01 01:01:01'), + |(2, '2022-02-02 02:02:02') + |""".stripMargin + ) + } { df => + checkAnswer( + df, + Row(1, timestamp("2021-01-01T01:01:01Z")) :: + Row(2, timestamp("2022-02-02T02:02:02Z")) :: Nil + ) + checkAnswer( + df.filter("value > '2022-01-01 01:01:01'"), + Row(2, timestamp("2022-02-02T02:02:02Z")) :: Nil + ) + } + } + + // Decimal(P, S): P - precision, S - scale, which have different support range in Spark and ClickHouse. + // + // Spark: + // Decimal(P, S): P: [ 1:38]; S: [0:P] + // ClickHouse: + // Decimal(P, S): P: [ 1:76]; S: [0:P] + // Decimal32(S): P: [ 1: 9]; S: [0:P] + // Decimal64(S): P: [10:18]; S: [0:P] + // Decimal128(S): P: [19:38]; S: [0:P] + // Decimal256(S): P: [39:76]; S: [0:P] + Seq( + ("Decimal(38,9)", 38, 9), + ("Decimal32(4)", 9, 4), + ("Decimal64(4)", 18, 4), + ("Decimal128(4)", 38, 4) + ).foreach { case (dataType, p, s) => + test(s"DataType - $dataType") { + testDataType(dataType) { (db, tbl) => + runClickHouseSQL( + s"""INSERT INTO $db.$tbl VALUES + |(1, '11.1') + |""".stripMargin + ) + } { df => + assert(df.schema.length === 2) + assert(df.schema.fields(1).dataType === DecimalType(p, s)) + checkAnswer( + df, + Row(1, BigDecimal("11.1", new MathContext(p))) :: Nil + ) + } + } + } + + private def testDataType(valueColDef: String)(prepare: (String, String) => Unit)(validate: DataFrame => Unit) + : Unit = { + val db = "test_kv_db" + val tbl = "test_kv_tbl" + if (!clickhouseVersion.isNewerOrEqualTo("23.3") || isCloud) { + Thread.sleep(1000) + } + withKVTable(db, tbl, valueColDef = valueColDef) { + prepare(db, tbl) + val df = spark.sql(s"SELECT key, value FROM $db.$tbl ORDER BY key") + validate(df) + } + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseGenericSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseGenericSuite.scala new file mode 100644 index 00000000..74725190 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseGenericSuite.scala @@ -0,0 +1,495 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.clickhouse.single
+
+import com.clickhouse.spark.base.{ClickHouseCloudMixIn, ClickHouseSingleMixIn}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.types._
+import org.scalatest.tags.Cloud
+
+@Cloud
+class ClickHouseCloudGenericSuite extends ClickHouseGenericSuite with ClickHouseCloudMixIn
+
+class ClickHouseSingleGenericSuite extends ClickHouseGenericSuite with ClickHouseSingleMixIn
+
+abstract class ClickHouseGenericSuite extends SparkClickHouseSingleTest {
+
+  import testImplicits._
+
+  test("clickhouse command runner") {
+    checkAnswer(
+      runClickHouseSQL("SELECT visibleWidth(NULL)"),
+      Row("""{"visibleWidth(NULL)":"4"}""") :: Nil
+    )
+  }
+
+  test("clickhouse catalog") {
+    withDatabase("db_t1", "db_t2") {
+      spark.sql("CREATE DATABASE db_t1")
+      spark.sql("CREATE DATABASE db_t2")
+      checkAnswer(
+        spark.sql("SHOW DATABASES LIKE 'db_t*'"),
+        Row("db_t1") :: Row("db_t2") :: Nil
+      )
+      spark.sql("USE system")
+      checkAnswer(
+        spark.sql("SELECT current_database()"),
+        Row("system") :: Nil
+      )
+      assert(spark.sql("SHOW tables").where($"tableName" === "contributors").count === 1)
+    }
+  }
+
+  test("clickhouse system table") {
+    checkAnswer(
+      spark.sql("SELECT time_zone FROM `system`.`time_zones` WHERE time_zone = 'Asia/Shanghai'"),
+      Row("Asia/Shanghai") :: Nil
+    )
+  }
+
+  test("clickhouse partition") {
+    val db = "db_part"
+    val tbl = "tbl_part"
+
+    // DROP + PURGE
+    withSimpleTable(db, tbl, true) {
+      checkAnswer(
+        spark.sql(s"SHOW PARTITIONS $db.$tbl"),
+        Seq(Row("m=1"), Row("m=2"))
+      )
+      checkAnswer(
+        spark.sql(s"SHOW PARTITIONS $db.$tbl PARTITION(m = 2)"),
+        Seq(Row("m=2"))
+      )
+
+      spark.sql(s"ALTER TABLE $db.$tbl DROP PARTITION(m = 2)")
+      checkAnswer(
+        spark.sql(s"SHOW PARTITIONS $db.$tbl"),
+        Seq(Row("m=1"))
+      )
+
+      spark.sql(s"ALTER TABLE $db.$tbl DROP PARTITION(m = 1) PURGE")
+      checkAnswer(
+        spark.sql(s"SHOW PARTITIONS $db.$tbl"),
+        Seq()
+      )
+    }
+
+    // DROP + TRUNCATE
+    withSimpleTable(db, tbl, true) {
+      checkAnswer(
+        spark.sql(s"SHOW PARTITIONS $db.$tbl"),
+        Seq(Row("m=1"), Row("m=2"))
+      )
+      checkAnswer(
+        spark.sql(s"SHOW PARTITIONS $db.$tbl PARTITION(m = 2)"),
+        Seq(Row("m=2"))
+      )
+
+      spark.sql(s"ALTER TABLE $db.$tbl DROP PARTITION(m = 2)")
+      checkAnswer(
+        spark.sql(s"SHOW PARTITIONS $db.$tbl"),
+        Seq(Row("m=1"))
+      )
+
+      spark.sql(s"TRUNCATE TABLE $db.$tbl PARTITION(m = 1)")
+      checkAnswer(
+        spark.sql(s"SHOW PARTITIONS $db.$tbl"),
+        Seq()
+      )
+    }
+  }
+
+  test("clickhouse partition (date type)") {
+    val db = "db_part_date"
+    val tbl = "tbl_part_date"
+    val schema =
+      StructType(
+        StructField("id", LongType, false) ::
+          StructField("date", DateType, false) :: Nil
+      )
+    withTable(db, tbl, schema, partKeys = Seq("date")) {
+      spark.sql(
+        s"""INSERT INTO `$db`.`$tbl`
+           |VALUES
+           |  (11L, "2022-04-11"),
+           |  (12L, "2022-04-12") AS tab(id, date)
+           |""".stripMargin
+      )
+      spark.createDataFrame(Seq(
+        (21L, date("2022-04-21")),
+        (22L, date("2022-04-22"))
+      ))
+        .toDF("id", "date")
.writeTo(s"$db.$tbl").append + + checkAnswer( + spark.table(s"$db.$tbl").orderBy($"id"), + Row(11L, date("2022-04-11")) :: + Row(12L, date("2022-04-12")) :: + Row(21L, date("2022-04-21")) :: + Row(22L, date("2022-04-22")) :: Nil + ) + + checkAnswer( + spark.sql(s"SHOW PARTITIONS $db.$tbl"), + Seq( + Row("date=2022-04-11"), + Row("date=2022-04-12"), + Row("date=2022-04-21"), + Row("date=2022-04-22") + ) + ) + } + } + + test("clickhouse multi part columns") { + val db = "db_multi_part_col" + val tbl = "tbl_multi_part_col" + val schema = + StructType( + StructField("id", LongType, false) :: + StructField("value", StringType, false) :: + StructField("part_1", StringType, false) :: + StructField("part_2", IntegerType, false) :: Nil + ) + withTable(db, tbl, schema, partKeys = Seq("part_1", "part_2")) { + spark.sql( + s"""INSERT INTO `$db`.`$tbl` + |VALUES + | (11L, 'one_one', '1', 1), + | (12L, 'one_two', '1', 2) AS tab(id, value, part_1, part_2) + |""".stripMargin + ) + + spark.createDataFrame(Seq( + (21L, "two_one", "2", 1), + (22L, "two_two", "2", 2) + )) + .toDF("id", "value", "part_1", "part_2") + .writeTo(s"$db.$tbl").append + + checkAnswer( + spark.table(s"$db.$tbl").orderBy($"id"), + Row(11L, "one_one", "1", 1) :: + Row(12L, "one_two", "1", 2) :: + Row(21L, "two_one", "2", 1) :: + Row(22L, "two_two", "2", 2) :: Nil + ) + + checkAnswer( + spark.sql(s"SHOW PARTITIONS $db.$tbl"), + Seq( + Row("part_1=1/part_2=1"), + Row("part_1=1/part_2=2"), + Row("part_1=2/part_2=1"), + Row("part_1=2/part_2=2") + ) + ) + } + } + + test("clickhouse multi part columns (date type)") { + val db = "db_mul_part_date" + val tbl = "tbl_mul_part_date" + val schema = + StructType( + StructField("id", LongType, false) :: + StructField("part_1", DateType, false) :: + StructField("part_2", IntegerType, false) :: Nil + ) + withTable(db, tbl, schema, partKeys = Seq("part_1", "part_2")) { + spark.sql( + s"""INSERT INTO `$db`.`$tbl` + |VALUES + | (11L, "2022-04-11", 1), + | (12L, "2022-04-12", 2) AS tab(id, part_1, part_2) + |""".stripMargin + ) + spark.createDataFrame(Seq( + (21L, "2022-04-21", 1), + (22L, "2022-04-22", 2) + )).toDF("id", "part_1", "part_2") + .writeTo(s"$db.$tbl").append + + checkAnswer( + spark.table(s"$db.$tbl").orderBy($"id"), + Row(11L, date("2022-04-11"), 1) :: + Row(12L, date("2022-04-12"), 2) :: + Row(21L, date("2022-04-21"), 1) :: + Row(22L, date("2022-04-22"), 2) :: Nil + ) + + checkAnswer( + spark.sql(s"SHOW PARTITIONS $db.$tbl"), + Seq( + Row("part_1=2022-04-11/part_2=1"), + Row("part_1=2022-04-12/part_2=2"), + Row("part_1=2022-04-21/part_2=1"), + Row("part_1=2022-04-22/part_2=2") + ) + ) + } + } + + // TODO remove this hack version + test("clickhouse partition toYYYYMMDD(toDate(col))") { + val db = "db_part_toYYYYMMDD_toDate" + val tbl = "tbl_part_toYYYYMMDD_toDate" + autoCleanupTable(db, tbl) { case (db, tbl) => + runClickHouseSQL( + s"""CREATE TABLE IF NOT EXISTS `$db`.`$tbl` ( + | `id` Int64, + | `dt` String + |) ENGINE = MergeTree + |PARTITION BY toYYYYMMDD(toDate(dt)) + |ORDER BY (id) + |""".stripMargin + ) + spark.createDataFrame(Seq( + (1L, "2022-06-06"), + (2L, "2022-06-07") + )).toDF("id", "dt") + .writeTo(s"$db.$tbl").append + checkAnswer( + spark.sql(s"SHOW PARTITIONS $db.$tbl"), + Seq( + Row("dt=20220606"), + Row("dt=20220607") + ) + ) + checkAnswer( + spark.table(s"$db.$tbl").orderBy($"id"), + Seq( + Row(1L, "2022-06-06"), + Row(2L, "2022-06-07") + ) + ) + } + } + + test("clickhouse multi sort columns") { + val db = "db_multi_sort_col" + val tbl = "tbl_multi_sort_col" + val 
schema = + StructType( + StructField("id", LongType, false) :: + StructField("value", StringType, false) :: + StructField("sort_2", StringType, false) :: + StructField("sort_3", IntegerType, false) :: Nil + ) + withTable(db, tbl, schema, sortKeys = Seq("sort_2", "sort_3")) { + spark.sql( + s"""INSERT INTO `$db`.`$tbl` + |VALUES + | (11L, 'one_one', '1', 1), + | (12L, 'one_two', '1', 2) AS tab(id, value, sort_2, sort_3) + |""".stripMargin + ) + + spark.createDataFrame(Seq( + (21L, "two_one", "2", 1), + (22L, "two_two", "2", 2) + )) + .toDF("id", "value", "sort_2", "sort_3") + .writeTo(s"$db.$tbl").append + + checkAnswer( + spark.table(s"$db.$tbl").orderBy($"id"), + Row(11L, "one_one", "1", 1) :: + Row(12L, "one_two", "1", 2) :: + Row(21L, "two_one", "2", 1) :: + Row(22L, "two_two", "2", 2) :: Nil + ) + } + } + + test("clickhouse truncate table") { + withClickHouseSingleIdTable("db_trunc", "tbl_trunc") { (db, tbl) => + spark.range(10).toDF("id").writeTo(s"$db.$tbl").append + assert(spark.table(s"$db.$tbl").count == 10) + spark.sql(s"TRUNCATE TABLE $db.$tbl") + assert(spark.table(s"$db.$tbl").count == 0) + } + } + + test("clickhouse delete") { + withClickHouseSingleIdTable("db_del", "tbl_db_del") { (db, tbl) => + spark.range(10).toDF("id").writeTo(s"$db.$tbl").append + assert(spark.table(s"$db.$tbl").count == 10) + spark.sql(s"DELETE FROM $db.$tbl WHERE id < 5") + assert(spark.table(s"$db.$tbl").count == 5) + } + } + + test("clickhouse write then read") { + val db = "db_rw" + val tbl = "tbl_rw" + + withSimpleTable(db, tbl, true) { + val tblSchema = spark.table(s"$db.$tbl").schema + assert(tblSchema == StructType( + StructField("id", DataTypes.LongType, false) :: + StructField("value", DataTypes.StringType, true) :: + StructField("create_time", DataTypes.TimestampType, false) :: + StructField("m", DataTypes.IntegerType, false) :: Nil + )) + + checkAnswer( + spark.table(s"$db.$tbl").sort("m"), + Seq( + Row(1L, "1", timestamp("2021-01-01T10:10:10Z"), 1), + Row(2L, "2", timestamp("2022-02-02T10:10:10Z"), 2) + ) + ) + + checkAnswer( + spark.table(s"$db.$tbl").filter($"id" > 1), + Row(2L, "2", timestamp("2022-02-02T10:10:10Z"), 2) :: Nil + ) + + assert(spark.table(s"$db.$tbl").filter($"id" > 1).count === 1) + + // infiniteLoop() + } + } + + test("clickhouse metadata column") { + val db = "db_metadata_col" + val tbl = "tbl_metadata_col" + + withSimpleTable(db, tbl, true) { + checkAnswer( + spark.sql(s"SELECT m, _partition_id FROM $db.$tbl ORDER BY m"), + Seq( + Row(1, "1"), + Row(2, "2") + ) + ) + } + } + + test("push down limit") { + checkAnswer( + spark.sql(s"SELECT zero FROM system.zeros LIMIT 2"), + Seq(Row(0), Row(0)) + ) + } + + test("push down aggregation") { + val db = "db_agg_col" + val tbl = "tbl_agg_col" + + withSimpleTable(db, tbl, true) { + checkAnswer( + spark.sql(s"SELECT COUNT(id) FROM $db.$tbl"), + Seq(Row(2)) + ) + + checkAnswer( + spark.sql(s"SELECT MIN(id) FROM $db.$tbl"), + Seq(Row(1)) + ) + + checkAnswer( + spark.sql(s"SELECT MAX(id) FROM $db.$tbl"), + Seq(Row(2)) + ) + + checkAnswer( + spark.sql(s"SELECT m, COUNT(DISTINCT id) FROM $db.$tbl GROUP BY m"), + Seq( + Row(1, 1), + Row(2, 1) + ) + ) + + checkAnswer( + spark.sql(s"SELECT m, SUM(DISTINCT id) FROM $db.$tbl GROUP BY m"), + Seq( + Row(1, 1), + Row(2, 2) + ) + ) + } + } + + test("create or replace table") { + autoCleanupTable("db_cor", "tbl_cor") { (db, tbl) => + def createOrReplaceTable(): Unit = spark.sql( + s"""CREATE OR REPLACE TABLE `$db`.`$tbl` ( + | id Long NOT NULL + |) USING ClickHouse + |TBLPROPERTIES ( + | 
engine = 'MergeTree()', + | order_by = 'id', + | settings.index_granularity = 8192 + |) + |""".stripMargin + ) + createOrReplaceTable() + createOrReplaceTable() + } + } + + test("cache table") { + val db = "cache_db" + val tbl = "cache_tbl" + + withSimpleTable(db, tbl, true) { + try { + spark.sql(s"CACHE TABLE $db.$tbl") + val cachedPlan = spark.sql(s"SELECT * FROM $db.$tbl").queryExecution.commandExecuted + .find(node => spark.sharedState.cacheManager.lookupCachedData(node).isDefined) + assert(cachedPlan.isDefined) + } finally + spark.sql(s"UNCACHE TABLE $db.$tbl") + } + } + + test("runtime filter") { + val db = "runtime_db" + val tbl = "runtime_tbl" + + withSimpleTable(db, tbl, true) { + spark.sql("set spark.clickhouse.read.runtimeFilter.enabled=false") + checkAnswer( + spark.sql(s"SELECT id FROM $db.$tbl " + + s"WHERE id IN (" + + s" SELECT id FROM $db.$tbl " + + s" WHERE DATE_FORMAT(create_time, 'yyyy-MM-dd') between '2021-01-01' and '2022-01-01'" + + s")"), + Row(1) + ) + + spark.sql("set spark.clickhouse.read.runtimeFilter.enabled=true") + val df = spark.sql(s"SELECT id FROM $db.$tbl " + + s"WHERE id IN (" + + s" SELECT id FROM $db.$tbl " + + s" WHERE DATE_FORMAT(create_time, 'yyyy-MM-dd') between '2021-01-01' and '2022-01-01'" + + s")") + checkAnswer(df, Row(1)) + val runtimeFilterExists = df.queryExecution.sparkPlan.exists { + case BatchScanExec(_, _, runtimeFilters, _, table, _) + if table.name() == TableIdentifier(tbl, Some(db)).quotedString + && runtimeFilters.nonEmpty => true + case _ => false + } + assert(runtimeFilterExists) + } + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseJsonReaderSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseJsonReaderSuite.scala new file mode 100644 index 00000000..c62d5564 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseJsonReaderSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse.single + +import com.clickhouse.spark.base.{ClickHouseCloudMixIn, ClickHouseSingleMixIn} +import org.scalatest.tags.Cloud + +@Cloud +class ClickHouseCloudJsonReaderSuite extends ClickHouseJsonReaderSuite with ClickHouseCloudMixIn + +class ClickHouseSingleJsonReaderSuite extends ClickHouseJsonReaderSuite with ClickHouseSingleMixIn + +/** + * Test suite for ClickHouse JSON Reader. + * Uses JSON format for reading data from ClickHouse (default in SparkClickHouseSingleTest). + * All test cases are inherited from ClickHouseReaderTestBase. 
+ */ +abstract class ClickHouseJsonReaderSuite extends ClickHouseReaderTestBase { + // Uses JSON format (configured in SparkClickHouseSingleTest) + // All tests are inherited from ClickHouseReaderTestBase + // Additional JSON-specific tests can be added here if needed +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseJsonWriterSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseJsonWriterSuite.scala new file mode 100644 index 00000000..3532b140 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseJsonWriterSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse.single + +import com.clickhouse.spark.base.ClickHouseSingleMixIn +import org.apache.spark.SparkConf + +class ClickHouseSingleJsonWriterSuite extends ClickHouseJsonWriterSuite with ClickHouseSingleMixIn + +abstract class ClickHouseJsonWriterSuite extends ClickHouseWriterTestBase { + + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.clickhouse.write.format", "json") + .set("spark.clickhouse.read.format", "json") + +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseReaderTestBase.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseReaderTestBase.scala new file mode 100644 index 00000000..73e9119f --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseReaderTestBase.scala @@ -0,0 +1,1331 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse.single + +import org.apache.spark.sql.Row + +/** + * Shared test cases for both JSON and Binary readers. + * Subclasses only need to configure the read format. + * + * Tests are organized by ClickHouse data type with both regular and nullable variants. + * Each type includes comprehensive coverage of edge cases and null handling. 
+ */ +trait ClickHouseReaderTestBase extends SparkClickHouseSingleTest { + + // ============================================================================ + // ArrayType Tests + // ============================================================================ + + test("decode ArrayType - Array of integers") { + withKVTable("test_db", "test_array_int", valueColDef = "Array(Int32)") { + runClickHouseSQL( + """INSERT INTO test_db.test_array_int VALUES + |(1, [1, 2, 3]), + |(2, []), + |(3, [100, 200, 300, 400]) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_array_int ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getSeq[Int](1) == Seq(1, 2, 3)) + assert(result(1).getSeq[Int](1) == Seq()) + assert(result(2).getSeq[Int](1) == Seq(100, 200, 300, 400)) + } + } + test("decode ArrayType - Array of strings") { + withKVTable("test_db", "test_array_string", valueColDef = "Array(String)") { + runClickHouseSQL( + """INSERT INTO test_db.test_array_string VALUES + |(1, ['hello', 'world']), + |(2, []), + |(3, ['a', 'b', 'c']) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_array_string ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getSeq[String](1) == Seq("hello", "world")) + assert(result(1).getSeq[String](1) == Seq()) + assert(result(2).getSeq[String](1) == Seq("a", "b", "c")) + } + } + test("decode ArrayType - Array with nullable elements") { + withKVTable("test_db", "test_array_nullable", valueColDef = "Array(Nullable(Int32))") { + runClickHouseSQL( + """INSERT INTO test_db.test_array_nullable VALUES + |(1, [1, NULL, 3]), + |(2, [NULL, NULL]), + |(3, [100, 200]) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_array_nullable ORDER BY key") + val result = df.collect() + assert(result.length == 3) + // Verify arrays can be read + assert(result(0).getSeq[Any](1) != null) + assert(result(1).getSeq[Any](1) != null) + assert(result(2).getSeq[Any](1) != null) + } + } + test("decode ArrayType - empty arrays") { + withKVTable("test_db", "test_empty_array", valueColDef = "Array(Int32)") { + runClickHouseSQL( + """INSERT INTO test_db.test_empty_array VALUES + |(1, []), + |(2, [1, 2, 3]), + |(3, []) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_empty_array ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getSeq[Int](1).isEmpty) + assert(result(1).getSeq[Int](1) == Seq(1, 2, 3)) + assert(result(2).getSeq[Int](1).isEmpty) + } + } + test("decode ArrayType - Nested arrays") { + withKVTable("test_db", "test_nested_array", valueColDef = "Array(Array(Int32))") { + runClickHouseSQL( + """INSERT INTO test_db.test_nested_array VALUES + |(1, [[1, 2], [3, 4]]), + |(2, [[], [5]]), + |(3, [[10, 20, 30]]) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_nested_array ORDER BY key") + val result = df.collect() + assert(result.length == 3) + // Verify nested arrays can be read + assert(result(0).get(1) != null) + assert(result(1).get(1) != null) + assert(result(2).get(1) != null) + } + } + test("decode BinaryType - FixedString") { + // FixedString is read as String by default in the connector + withKVTable("test_db", "test_fixedstring", valueColDef = "FixedString(5)") { + runClickHouseSQL( + """INSERT INTO test_db.test_fixedstring VALUES + |(1, 'hello'), + |(2, 'world') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, 
value FROM test_db.test_fixedstring ORDER BY key") + val result = df.collect() + assert(result.length == 2) + // FixedString should be readable + assert(result(0).get(1) != null) + assert(result(1).get(1) != null) + } + } + test("decode BinaryType - FixedString nullable with null values") { + withKVTable("test_db", "test_fixedstring_null", valueColDef = "Nullable(FixedString(5))") { + runClickHouseSQL( + """INSERT INTO test_db.test_fixedstring_null VALUES + |(1, 'hello'), + |(2, NULL), + |(3, 'world') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_fixedstring_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).get(1) != null) + assert(result(1).isNullAt(1)) + assert(result(2).get(1) != null) + } + } + + // ============================================================================ + // BooleanType Tests + // ============================================================================ + + test("decode BooleanType - true and false values") { + // ClickHouse Bool is stored as UInt8 (0 or 1) + // JSON format reads as Boolean, Binary format reads as Short + withKVTable("test_db", "test_bool", valueColDef = "Bool") { + runClickHouseSQL( + """INSERT INTO test_db.test_bool VALUES + |(1, true), + |(2, false), + |(3, 1), + |(4, 0) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_bool ORDER BY key") + val result = df.collect() + assert(result.length == 4) + // Check the value - handle both Boolean (JSON) and Short (Binary) formats + val v0 = result(0).get(1) + val v1 = result(1).get(1) + v0 match { + case b: Boolean => + assert(b == true) + assert(result(1).getBoolean(1) == false) + assert(result(2).getBoolean(1) == true) + assert(result(3).getBoolean(1) == false) + case s: Short => + assert(s == 1) + assert(result(1).getShort(1) == 0) + assert(result(2).getShort(1) == 1) + assert(result(3).getShort(1) == 0) + case _ => fail(s"Unexpected type: ${v0.getClass}") + } + } + } + test("decode BooleanType - nullable with null values") { + withKVTable("test_db", "test_bool_null", valueColDef = "Nullable(Bool)") { + runClickHouseSQL( + """INSERT INTO test_db.test_bool_null VALUES + |(1, true), + |(2, NULL), + |(3, false) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_bool_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(1).isNullAt(1)) + // Check the value - handle both Boolean (JSON) and Short (Binary) formats + val v0 = result(0).get(1) + v0 match { + case b: Boolean => + assert(b == true) + assert(result(2).getBoolean(1) == false) + case s: Short => + assert(s == 1) + assert(result(2).getShort(1) == 0) + case _ => fail(s"Unexpected type: ${v0.getClass}") + } + } + } + + // ============================================================================ + // ByteType Tests + // ============================================================================ + + test("decode ByteType - min and max values") { + withKVTable("test_db", "test_byte", valueColDef = "Int8") { + runClickHouseSQL( + """INSERT INTO test_db.test_byte VALUES + |(1, -128), + |(2, 0), + |(3, 127) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_byte ORDER BY key") + checkAnswer( + df, + Row(1, -128.toByte) :: Row(2, 0.toByte) :: Row(3, 127.toByte) :: Nil + ) + } + } + test("decode ByteType - nullable with null values") { + withKVTable("test_db", "test_byte_null", valueColDef = "Nullable(Int8)") { + runClickHouseSQL( + 
"""INSERT INTO test_db.test_byte_null VALUES + |(1, -128), + |(2, NULL), + |(3, 127) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_byte_null ORDER BY key") + checkAnswer( + df, + Row(1, -128.toByte) :: Row(2, null) :: Row(3, 127.toByte) :: Nil + ) + } + } + test("decode DateTime32 - 32-bit timestamp") { + withKVTable("test_db", "test_datetime32", valueColDef = "DateTime32") { + runClickHouseSQL( + """INSERT INTO test_db.test_datetime32 VALUES + |(1, '2024-01-01 12:00:00'), + |(2, '2024-06-15 18:30:45') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_datetime32 ORDER BY key") + val result = df.collect() + assert(result.length == 2) + assert(result(0).getTimestamp(1) != null) + assert(result(1).getTimestamp(1) != null) + } + } + test("decode DateTime32 - nullable with null values") { + withKVTable("test_db", "test_datetime32_null", valueColDef = "Nullable(DateTime32)") { + runClickHouseSQL( + """INSERT INTO test_db.test_datetime32_null VALUES + |(1, '2024-01-01 12:00:00'), + |(2, NULL), + |(3, '2024-06-15 18:30:45') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_datetime32_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getTimestamp(1) != null) + assert(result(1).isNullAt(1)) + assert(result(2).getTimestamp(1) != null) + } + } + test("decode DateType - Date") { + withKVTable("test_db", "test_date", valueColDef = "Date") { + runClickHouseSQL( + """INSERT INTO test_db.test_date VALUES + |(1, '2024-01-01'), + |(2, '2024-06-15'), + |(3, '2024-12-31') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_date ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getDate(1) != null) + assert(result(1).getDate(1) != null) + assert(result(2).getDate(1) != null) + } + } + test("decode DateType - Date32") { + withKVTable("test_db", "test_date32", valueColDef = "Date32") { + runClickHouseSQL( + """INSERT INTO test_db.test_date32 VALUES + |(1, '1900-01-01'), + |(2, '2024-06-15'), + |(3, '2100-12-31') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_date32 ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getDate(1) != null) + assert(result(1).getDate(1) != null) + assert(result(2).getDate(1) != null) + } + } + test("decode DateType - Date32 nullable with null values") { + withKVTable("test_db", "test_date32_null", valueColDef = "Nullable(Date32)") { + runClickHouseSQL( + """INSERT INTO test_db.test_date32_null VALUES + |(1, '1900-01-01'), + |(2, NULL), + |(3, '2100-12-31') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_date32_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getDate(1) != null) + assert(result(1).isNullAt(1)) + assert(result(2).getDate(1) != null) + } + } + test("decode DateType - nullable with null values") { + withKVTable("test_db", "test_date_null", valueColDef = "Nullable(Date)") { + runClickHouseSQL( + """INSERT INTO test_db.test_date_null VALUES + |(1, '2024-01-01'), + |(2, NULL), + |(3, '2024-12-31') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_date_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getDate(1) != null) + assert(result(1).isNullAt(1)) + assert(result(2).getDate(1) != null) + } + } + test("decode DecimalType - 
Decimal128") { + // Decimal128(20) means scale=20, max precision=38 total digits + // Use values with max 18 digits before decimal to stay within 38 total + withKVTable("test_db", "test_decimal128", valueColDef = "Decimal128(20)") { + runClickHouseSQL( + """INSERT INTO test_db.test_decimal128 VALUES + |(1, 123456789012345.12345678901234567890), + |(2, -999999999999999.99999999999999999999), + |(3, 0.00000000000000000001) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_decimal128 ORDER BY key") + val result = df.collect() + assert(result.length == 3) + // Decimal128(20) means 20 decimal places, total precision up to 38 digits + assert(math.abs(result(0).getDecimal(1).doubleValue() - 123456789012345.12345678901234567890) < 0.01) + assert(math.abs(result(1).getDecimal(1).doubleValue() - -999999999999999.99999999999999999999) < 0.01) + assert(result(2).getDecimal(1) != null) + } + } + test("decode DecimalType - Decimal128 nullable with null values") { + withKVTable("test_db", "test_decimal128_null", valueColDef = "Nullable(Decimal128(20))") { + runClickHouseSQL( + """INSERT INTO test_db.test_decimal128_null VALUES + |(1, 123456789012345.12345678901234567890), + |(2, NULL), + |(3, 0.00000000000000000001) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_decimal128_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getDecimal(1) != null) + assert(result(1).isNullAt(1)) + assert(result(2).getDecimal(1) != null) + } + } + test("decode DecimalType - Decimal32") { + withKVTable("test_db", "test_decimal32", valueColDef = "Decimal32(4)") { + runClickHouseSQL( + """INSERT INTO test_db.test_decimal32 VALUES + |(1, 12345.6789), + |(2, -9999.9999), + |(3, 0.0001) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_decimal32 ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getDecimal(1).doubleValue() == 12345.6789) + assert(result(1).getDecimal(1).doubleValue() == -9999.9999) + assert(result(2).getDecimal(1).doubleValue() == 0.0001) + } + } + test("decode DecimalType - Decimal32 nullable with null values") { + withKVTable("test_db", "test_decimal32_null", valueColDef = "Nullable(Decimal32(4))") { + runClickHouseSQL( + """INSERT INTO test_db.test_decimal32_null VALUES + |(1, 12345.6789), + |(2, NULL), + |(3, 0.0001) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_decimal32_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getDecimal(1) != null) + assert(result(1).isNullAt(1)) + assert(result(2).getDecimal(1) != null) + } + } + test("decode DecimalType - Decimal64") { + // Decimal64(10) means scale=10, max precision=18 total digits + // Use values with max 8 digits before decimal to stay within 18 total + withKVTable("test_db", "test_decimal64", valueColDef = "Decimal64(10)") { + runClickHouseSQL( + """INSERT INTO test_db.test_decimal64 VALUES + |(1, 1234567.0123456789), + |(2, -9999999.9999999999), + |(3, 0.0000000001) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_decimal64 ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(math.abs(result(0).getDecimal(1).doubleValue() - 1234567.0123456789) < 0.0001) + assert(math.abs(result(1).getDecimal(1).doubleValue() - -9999999.9999999999) < 0.0001) + assert(math.abs(result(2).getDecimal(1).doubleValue() - 0.0000000001) < 0.0000000001) + } 
+ } + test("decode DecimalType - Decimal64 nullable with null values") { + withKVTable("test_db", "test_decimal64_null", valueColDef = "Nullable(Decimal64(10))") { + runClickHouseSQL( + """INSERT INTO test_db.test_decimal64_null VALUES + |(1, 1234567.0123456789), + |(2, NULL), + |(3, 0.0000000001) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_decimal64_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getDecimal(1) != null) + assert(result(1).isNullAt(1)) + assert(result(2).getDecimal(1) != null) + } + } + test("decode DoubleType - nullable with null values") { + withKVTable("test_db", "test_double_null", valueColDef = "Nullable(Float64)") { + runClickHouseSQL( + """INSERT INTO test_db.test_double_null VALUES + |(1, 1.23), + |(2, NULL), + |(3, -4.56) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_double_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(math.abs(result(0).getDouble(1) - 1.23) < 0.0001) + assert(result(1).isNullAt(1)) + assert(math.abs(result(2).getDouble(1) - -4.56) < 0.0001) + } + } + test("decode DoubleType - regular values") { + withKVTable("test_db", "test_double", valueColDef = "Float64") { + runClickHouseSQL( + """INSERT INTO test_db.test_double VALUES + |(1, -3.141592653589793), + |(2, 0.0), + |(3, 3.141592653589793) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_double ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(math.abs(result(0).getDouble(1) - -3.141592653589793) < 0.000001) + assert(result(1).getDouble(1) == 0.0) + assert(math.abs(result(2).getDouble(1) - 3.141592653589793) < 0.000001) + } + } + test("decode Enum16 - large enum") { + withKVTable("test_db", "test_enum16", valueColDef = "Enum16('small' = 1, 'medium' = 100, 'large' = 1000)") { + runClickHouseSQL( + """INSERT INTO test_db.test_enum16 VALUES + |(1, 'small'), + |(2, 'medium'), + |(3, 'large') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_enum16 ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getString(1) == "small") + assert(result(1).getString(1) == "medium") + assert(result(2).getString(1) == "large") + } + } + test("decode Enum16 - nullable with null values") { + withKVTable( + "test_db", + "test_enum16_null", + valueColDef = "Nullable(Enum16('small' = 1, 'medium' = 100, 'large' = 1000))" + ) { + runClickHouseSQL( + """INSERT INTO test_db.test_enum16_null VALUES + |(1, 'small'), + |(2, NULL), + |(3, 'large') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_enum16_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getString(1) == "small") + assert(result(1).isNullAt(1)) + assert(result(2).getString(1) == "large") + } + } + test("decode Enum8 - nullable with null values") { + withKVTable("test_db", "test_enum8_null", valueColDef = "Nullable(Enum8('red' = 1, 'green' = 2, 'blue' = 3))") { + runClickHouseSQL( + """INSERT INTO test_db.test_enum8_null VALUES + |(1, 'red'), + |(2, NULL), + |(3, 'blue') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_enum8_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getString(1) == "red") + assert(result(1).isNullAt(1)) + assert(result(2).getString(1) == "blue") + } + } + test("decode Enum8 - small enum") { + 
withKVTable("test_db", "test_enum8", valueColDef = "Enum8('red' = 1, 'green' = 2, 'blue' = 3)") { + runClickHouseSQL( + """INSERT INTO test_db.test_enum8 VALUES + |(1, 'red'), + |(2, 'green'), + |(3, 'blue') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_enum8 ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getString(1) == "red") + assert(result(1).getString(1) == "green") + assert(result(2).getString(1) == "blue") + } + } + test("decode FloatType - nullable with null values") { + withKVTable("test_db", "test_float_null", valueColDef = "Nullable(Float32)") { + runClickHouseSQL( + """INSERT INTO test_db.test_float_null VALUES + |(1, 1.5), + |(2, NULL), + |(3, -2.5) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_float_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(math.abs(result(0).getFloat(1) - 1.5f) < 0.01f) + assert(result(1).isNullAt(1)) + assert(math.abs(result(2).getFloat(1) - -2.5f) < 0.01f) + } + } + test("decode FloatType - regular values") { + withKVTable("test_db", "test_float", valueColDef = "Float32") { + runClickHouseSQL( + """INSERT INTO test_db.test_float VALUES + |(1, -3.14), + |(2, 0.0), + |(3, 3.14) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_float ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(math.abs(result(0).getFloat(1) - -3.14f) < 0.01f) + assert(result(1).getFloat(1) == 0.0f) + assert(math.abs(result(2).getFloat(1) - 3.14f) < 0.01f) + } + } + test("decode Int128 - large integers as Decimal") { + withKVTable("test_db", "test_int128", valueColDef = "Int128") { + runClickHouseSQL( + """INSERT INTO test_db.test_int128 VALUES + |(1, 0), + |(2, 123456789012345678901234567890), + |(3, -123456789012345678901234567890) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_int128 ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getDecimal(1).toBigInteger.longValue == 0L) + assert(result(1).getDecimal(1) != null) + assert(result(2).getDecimal(1) != null) + } + } + test("decode Int128 - nullable with null values") { + withKVTable("test_db", "test_int128_null", valueColDef = "Nullable(Int128)") { + runClickHouseSQL( + """INSERT INTO test_db.test_int128_null VALUES + |(1, 0), + |(2, NULL), + |(3, -123456789012345678901234567890) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_int128_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getDecimal(1).toBigInteger.longValue == 0L) + assert(result(1).isNullAt(1)) + assert(result(2).getDecimal(1) != null) + } + } + test("decode Int256 - nullable with null values") { + withKVTable("test_db", "test_int256_null", valueColDef = "Nullable(Int256)") { + runClickHouseSQL( + """INSERT INTO test_db.test_int256_null VALUES + |(1, 0), + |(2, NULL), + |(3, 12345678901234567890123456789012345678) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_int256_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getDecimal(1).toBigInteger.longValue == 0L) + assert(result(1).isNullAt(1)) + assert(result(2).getDecimal(1) != null) + } + } + test("decode Int256 - very large integers as Decimal") { + withKVTable("test_db", "test_int256", valueColDef = "Int256") { + runClickHouseSQL( + """INSERT INTO 
test_db.test_int256 VALUES + |(1, 0), + |(2, 12345678901234567890123456789012345678) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_int256 ORDER BY key") + val result = df.collect() + assert(result.length == 2) + assert(result(0).getDecimal(1).toBigInteger.longValue == 0L) + assert(result(1).getDecimal(1) != null) + } + } + test("decode IntegerType - min and max values") { + withKVTable("test_db", "test_int", valueColDef = "Int32") { + runClickHouseSQL( + """INSERT INTO test_db.test_int VALUES + |(1, -2147483648), + |(2, 0), + |(3, 2147483647) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_int ORDER BY key") + checkAnswer( + df, + Row(1, -2147483648) :: Row(2, 0) :: Row(3, 2147483647) :: Nil + ) + } + } + test("decode IntegerType - nullable with null values") { + withKVTable("test_db", "test_int_null", valueColDef = "Nullable(Int32)") { + runClickHouseSQL( + """INSERT INTO test_db.test_int_null VALUES + |(1, -2147483648), + |(2, NULL), + |(3, 2147483647) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_int_null ORDER BY key") + checkAnswer( + df, + Row(1, -2147483648) :: Row(2, null) :: Row(3, 2147483647) :: Nil + ) + } + } + test("decode IPv4 - IP addresses") { + withKVTable("test_db", "test_ipv4", valueColDef = "IPv4") { + runClickHouseSQL( + """INSERT INTO test_db.test_ipv4 VALUES + |(1, '127.0.0.1'), + |(2, '192.168.1.1'), + |(3, '8.8.8.8') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_ipv4 ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getString(1) == "127.0.0.1") + assert(result(1).getString(1) == "192.168.1.1") + assert(result(2).getString(1) == "8.8.8.8") + } + } + test("decode IPv4 - nullable with null values") { + withKVTable("test_db", "test_ipv4_null", valueColDef = "Nullable(IPv4)") { + runClickHouseSQL( + """INSERT INTO test_db.test_ipv4_null VALUES + |(1, '127.0.0.1'), + |(2, NULL), + |(3, '8.8.8.8') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_ipv4_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getString(1) == "127.0.0.1") + assert(result(1).isNullAt(1)) + assert(result(2).getString(1) == "8.8.8.8") + } + } + test("decode IPv6 - IPv6 addresses") { + withKVTable("test_db", "test_ipv6", valueColDef = "IPv6") { + runClickHouseSQL( + """INSERT INTO test_db.test_ipv6 VALUES + |(1, '::1'), + |(2, '2001:0db8:85a3:0000:0000:8a2e:0370:7334'), + |(3, 'fe80::1') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_ipv6 ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getString(1) != null) + assert(result(1).getString(1) != null) + assert(result(2).getString(1) != null) + } + } + test("decode IPv6 - nullable with null values") { + withKVTable("test_db", "test_ipv6_null", valueColDef = "Nullable(IPv6)") { + runClickHouseSQL( + """INSERT INTO test_db.test_ipv6_null VALUES + |(1, '::1'), + |(2, NULL), + |(3, 'fe80::1') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_ipv6_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getString(1) != null) + assert(result(1).isNullAt(1)) + assert(result(2).getString(1) != null) + } + } + test("decode JSON - nullable with null values") { + withKVTable("test_db", "test_json_null", valueColDef = "Nullable(String)") { + runClickHouseSQL( 
+ """INSERT INTO test_db.test_json_null VALUES + |(1, '{"name": "Alice", "age": 30}'), + |(2, NULL), + |(3, '{"name": "Charlie", "age": 35}') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_json_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getString(1).contains("Alice")) + assert(result(1).isNullAt(1)) + assert(result(2).getString(1).contains("Charlie")) + } + } + test("decode JSON - semi-structured data") { + withKVTable("test_db", "test_json", valueColDef = "String") { + runClickHouseSQL( + """INSERT INTO test_db.test_json VALUES + |(1, '{"name": "Alice", "age": 30}'), + |(2, '{"name": "Bob", "age": 25}'), + |(3, '{"name": "Charlie", "age": 35}') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_json ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getString(1).contains("Alice")) + assert(result(1).getString(1).contains("Bob")) + assert(result(2).getString(1).contains("Charlie")) + } + } + test("decode LongType - min and max values") { + withKVTable("test_db", "test_long", valueColDef = "Int64") { + runClickHouseSQL( + """INSERT INTO test_db.test_long VALUES + |(1, -9223372036854775808), + |(2, 0), + |(3, 9223372036854775807) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_long ORDER BY key") + checkAnswer( + df, + Row(1, -9223372036854775808L) :: Row(2, 0L) :: Row(3, 9223372036854775807L) :: Nil + ) + } + } + test("decode LongType - nullable with null values") { + withKVTable("test_db", "test_long_null", valueColDef = "Nullable(Int64)") { + runClickHouseSQL( + """INSERT INTO test_db.test_long_null VALUES + |(1, -9223372036854775808), + |(2, NULL), + |(3, 9223372036854775807) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_long_null ORDER BY key") + checkAnswer( + df, + Row(1, -9223372036854775808L) :: Row(2, null) :: Row(3, 9223372036854775807L) :: Nil + ) + } + } + test("decode LongType - UInt32 nullable with null values") { + withKVTable("test_db", "test_uint32_null", valueColDef = "Nullable(UInt32)") { + runClickHouseSQL( + """INSERT INTO test_db.test_uint32_null VALUES + |(1, 0), + |(2, NULL), + |(3, 4294967295) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_uint32_null ORDER BY key") + checkAnswer( + df, + Row(1, 0L) :: Row(2, null) :: Row(3, 4294967295L) :: Nil + ) + } + } + test("decode LongType - UInt32 values") { + withKVTable("test_db", "test_uint32", valueColDef = "UInt32") { + runClickHouseSQL( + """INSERT INTO test_db.test_uint32 VALUES + |(1, 0), + |(2, 2147483648), + |(3, 4294967295) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_uint32 ORDER BY key") + checkAnswer( + df, + Row(1, 0L) :: Row(2, 2147483648L) :: Row(3, 4294967295L) :: Nil + ) + } + } + test("decode MapType - Map of String to Int") { + withKVTable("test_db", "test_map", valueColDef = "Map(String, Int32)") { + runClickHouseSQL( + """INSERT INTO test_db.test_map VALUES + |(1, {'a': 1, 'b': 2}), + |(2, {}), + |(3, {'x': 100, 'y': 200, 'z': 300}) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_map ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getMap[String, Int](1) == Map("a" -> 1, "b" -> 2)) + assert(result(1).getMap[String, Int](1) == Map()) + assert(result(2).getMap[String, Int](1) == Map("x" -> 100, "y" -> 200, "z" -> 300)) + } + } 
+ test("decode MapType - Map with nullable values") { + withKVTable("test_db", "test_map_nullable", valueColDef = "Map(String, Nullable(Int32))") { + runClickHouseSQL( + """INSERT INTO test_db.test_map_nullable VALUES + |(1, {'a': 1, 'b': NULL}), + |(2, {'x': NULL}), + |(3, {'p': 100, 'q': 200}) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_map_nullable ORDER BY key") + val result = df.collect() + assert(result.length == 3) + // Verify maps can be read + assert(result(0).getMap[String, Any](1) != null) + assert(result(1).getMap[String, Any](1) != null) + assert(result(2).getMap[String, Any](1) != null) + } + } + test("decode ShortType - min and max values") { + withKVTable("test_db", "test_short", valueColDef = "Int16") { + runClickHouseSQL( + """INSERT INTO test_db.test_short VALUES + |(1, -32768), + |(2, 0), + |(3, 32767) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_short ORDER BY key") + checkAnswer( + df, + Row(1, -32768.toShort) :: Row(2, 0.toShort) :: Row(3, 32767.toShort) :: Nil + ) + } + } + test("decode ShortType - nullable with null values") { + withKVTable("test_db", "test_short_null", valueColDef = "Nullable(Int16)") { + runClickHouseSQL( + """INSERT INTO test_db.test_short_null VALUES + |(1, -32768), + |(2, NULL), + |(3, 32767) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_short_null ORDER BY key") + checkAnswer( + df, + Row(1, -32768.toShort) :: Row(2, null) :: Row(3, 32767.toShort) :: Nil + ) + } + } + test("decode ShortType - UInt8 nullable with null values") { + withKVTable("test_db", "test_uint8_null", valueColDef = "Nullable(UInt8)") { + runClickHouseSQL( + """INSERT INTO test_db.test_uint8_null VALUES + |(1, 0), + |(2, NULL), + |(3, 255) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_uint8_null ORDER BY key") + checkAnswer( + df, + Row(1, 0.toShort) :: Row(2, null) :: Row(3, 255.toShort) :: Nil + ) + } + } + test("decode ShortType - UInt8 values") { + withKVTable("test_db", "test_uint8", valueColDef = "UInt8") { + runClickHouseSQL( + """INSERT INTO test_db.test_uint8 VALUES + |(1, 0), + |(2, 128), + |(3, 255) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_uint8 ORDER BY key") + checkAnswer( + df, + Row(1, 0.toShort) :: Row(2, 128.toShort) :: Row(3, 255.toShort) :: Nil + ) + } + } + test("decode StringType - empty strings") { + withKVTable("test_db", "test_empty_string", valueColDef = "String") { + runClickHouseSQL( + """INSERT INTO test_db.test_empty_string VALUES + |(1, ''), + |(2, 'not empty'), + |(3, '') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_empty_string ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getString(1) == "") + assert(result(1).getString(1) == "not empty") + assert(result(2).getString(1) == "") + } + } + test("decode StringType - nullable with null values") { + withKVTable("test_db", "test_string_null", valueColDef = "Nullable(String)") { + runClickHouseSQL( + """INSERT INTO test_db.test_string_null VALUES + |(1, 'hello'), + |(2, NULL), + |(3, 'world') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_string_null ORDER BY key") + checkAnswer( + df, + Row(1, "hello") :: Row(2, null) :: Row(3, "world") :: Nil + ) + } + } + test("decode StringType - regular strings") { + withKVTable("test_db", "test_string", valueColDef = "String") { + runClickHouseSQL( + 
"""INSERT INTO test_db.test_string VALUES + |(1, 'hello'), + |(2, ''), + |(3, 'world with spaces'), + |(4, 'special chars: !@#$%^&*()') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_string ORDER BY key") + checkAnswer( + df, + Row(1, "hello") :: Row(2, "") :: Row(3, "world with spaces") :: Row(4, "special chars: !@#$%^&*()") :: Nil + ) + } + } + test("decode StringType - UUID") { + withKVTable("test_db", "test_uuid", valueColDef = "UUID") { + runClickHouseSQL( + """INSERT INTO test_db.test_uuid VALUES + |(1, '550e8400-e29b-41d4-a716-446655440000'), + |(2, '6ba7b810-9dad-11d1-80b4-00c04fd430c8') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_uuid ORDER BY key") + val result = df.collect() + assert(result.length == 2) + assert(result(0).getString(1) == "550e8400-e29b-41d4-a716-446655440000") + assert(result(1).getString(1) == "6ba7b810-9dad-11d1-80b4-00c04fd430c8") + } + } + test("decode StringType - UUID nullable with null values") { + withKVTable("test_db", "test_uuid_null", valueColDef = "Nullable(UUID)") { + runClickHouseSQL( + """INSERT INTO test_db.test_uuid_null VALUES + |(1, '550e8400-e29b-41d4-a716-446655440000'), + |(2, NULL), + |(3, '6ba7b810-9dad-11d1-80b4-00c04fd430c8') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_uuid_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getString(1) == "550e8400-e29b-41d4-a716-446655440000") + assert(result(1).isNullAt(1)) + assert(result(2).getString(1) == "6ba7b810-9dad-11d1-80b4-00c04fd430c8") + } + } + test("decode StringType - very long strings") { + val longString = "a" * 10000 + withKVTable("test_db", "test_long_string", valueColDef = "String") { + runClickHouseSQL( + s"""INSERT INTO test_db.test_long_string VALUES + |(1, '$longString') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_long_string ORDER BY key") + val result = df.collect() + assert(result.length == 1) + assert(result(0).getString(1).length == 10000) + } + } + test("decode TimestampType - DateTime") { + withKVTable("test_db", "test_datetime", valueColDef = "DateTime") { + runClickHouseSQL( + """INSERT INTO test_db.test_datetime VALUES + |(1, '2024-01-01 00:00:00'), + |(2, '2024-06-15 12:30:45'), + |(3, '2024-12-31 23:59:59') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_datetime ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getTimestamp(1) != null) + assert(result(1).getTimestamp(1) != null) + assert(result(2).getTimestamp(1) != null) + } + } + test("decode TimestampType - DateTime64") { + withKVTable("test_db", "test_datetime64", valueColDef = "DateTime64(3)") { + runClickHouseSQL( + """INSERT INTO test_db.test_datetime64 VALUES + |(1, '2024-01-01 00:00:00.123'), + |(2, '2024-06-15 12:30:45.456'), + |(3, '2024-12-31 23:59:59.999') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_datetime64 ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getTimestamp(1) != null) + assert(result(1).getTimestamp(1) != null) + assert(result(2).getTimestamp(1) != null) + } + } + test("decode TimestampType - DateTime64 nullable with null values") { + withKVTable("test_db", "test_datetime64_null", valueColDef = "Nullable(DateTime64(3))") { + runClickHouseSQL( + """INSERT INTO test_db.test_datetime64_null VALUES + |(1, '2024-01-01 00:00:00.123'), + |(2, 
NULL), + |(3, '2024-12-31 23:59:59.999') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_datetime64_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getTimestamp(1) != null) + assert(result(1).isNullAt(1)) + assert(result(2).getTimestamp(1) != null) + } + } + test("decode TimestampType - nullable with null values") { + withKVTable("test_db", "test_datetime_null", valueColDef = "Nullable(DateTime)") { + runClickHouseSQL( + """INSERT INTO test_db.test_datetime_null VALUES + |(1, '2024-01-01 00:00:00'), + |(2, NULL), + |(3, '2024-12-31 23:59:59') + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_datetime_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getTimestamp(1) != null) + assert(result(1).isNullAt(1)) + assert(result(2).getTimestamp(1) != null) + } + } + test("decode UInt128 - large unsigned integers as Decimal") { + withKVTable("test_db", "test_uint128", valueColDef = "UInt128") { + runClickHouseSQL( + """INSERT INTO test_db.test_uint128 VALUES + |(1, 0), + |(2, 123456789012345678901234567890) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_uint128 ORDER BY key") + val result = df.collect() + assert(result.length == 2) + assert(result(0).getDecimal(1).toBigInteger.longValue == 0L) + assert(result(1).getDecimal(1) != null) + } + } + test("decode UInt128 - nullable with null values") { + withKVTable("test_db", "test_uint128_null", valueColDef = "Nullable(UInt128)") { + runClickHouseSQL( + """INSERT INTO test_db.test_uint128_null VALUES + |(1, 0), + |(2, NULL), + |(3, 123456789012345678901234567890) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_uint128_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getDecimal(1).toBigInteger.longValue == 0L) + assert(result(1).isNullAt(1)) + assert(result(2).getDecimal(1) != null) + } + } + test("decode UInt16 - nullable with null values") { + withKVTable("test_db", "test_uint16_null", valueColDef = "Nullable(UInt16)") { + runClickHouseSQL( + """INSERT INTO test_db.test_uint16_null VALUES + |(1, 0), + |(2, NULL), + |(3, 65535) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_uint16_null ORDER BY key") + checkAnswer( + df, + Row(1, 0) :: Row(2, null) :: Row(3, 65535) :: Nil + ) + } + } + test("decode UInt16 - unsigned 16-bit integers") { + withKVTable("test_db", "test_uint16", valueColDef = "UInt16") { + runClickHouseSQL( + """INSERT INTO test_db.test_uint16 VALUES + |(1, 0), + |(2, 32768), + |(3, 65535) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_uint16 ORDER BY key") + checkAnswer( + df, + Row(1, 0) :: Row(2, 32768) :: Row(3, 65535) :: Nil + ) + } + } + test("decode UInt256 - nullable with null values") { + withKVTable("test_db", "test_uint256_null", valueColDef = "Nullable(UInt256)") { + runClickHouseSQL( + """INSERT INTO test_db.test_uint256_null VALUES + |(1, 0), + |(2, NULL), + |(3, 12345678901234567890123456789012345678) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_uint256_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getDecimal(1).toBigInteger.longValue == 0L) + assert(result(1).isNullAt(1)) + assert(result(2).getDecimal(1) != null) + } + } + test("decode UInt256 - very large unsigned integers as Decimal") { + 
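+ // UInt256 (like UInt128 above) can exceed the range of Spark's LongType, so these values surface
+ // on the Spark side as DecimalType (hence getDecimal in the assertions below).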
withKVTable("test_db", "test_uint256", valueColDef = "UInt256") { + runClickHouseSQL( + """INSERT INTO test_db.test_uint256 VALUES + |(1, 0), + |(2, 12345678901234567890123456789012345678) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_uint256 ORDER BY key") + val result = df.collect() + assert(result.length == 2) + assert(result(0).getDecimal(1).toBigInteger.longValue == 0L) + assert(result(1).getDecimal(1) != null) + } + } + test("decode UInt64 - nullable with null values") { + withKVTable("test_db", "test_uint64_null", valueColDef = "Nullable(UInt64)") { + runClickHouseSQL( + """INSERT INTO test_db.test_uint64_null VALUES + |(1, 0), + |(2, NULL), + |(3, 9223372036854775807) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_uint64_null ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getLong(1) == 0L) + assert(result(1).isNullAt(1)) + assert(result(2).getLong(1) == 9223372036854775807L) + } + } + test("decode UInt64 - unsigned 64-bit integers") { + withKVTable("test_db", "test_uint64", valueColDef = "UInt64") { + runClickHouseSQL( + """INSERT INTO test_db.test_uint64 VALUES + |(1, 0), + |(2, 1234567890), + |(3, 9223372036854775807) + |""".stripMargin + ) + + val df = spark.sql("SELECT key, value FROM test_db.test_uint64 ORDER BY key") + val result = df.collect() + assert(result.length == 3) + assert(result(0).getLong(1) == 0L) + assert(result(1).getLong(1) == 1234567890L) + // Max value that fits in signed Long + assert(result(2).getLong(1) == 9223372036854775807L) + } + } + +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseTableDDLSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseTableDDLSuite.scala new file mode 100644 index 00000000..f3918098 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseTableDDLSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse.single + +import com.clickhouse.spark.base.{ClickHouseCloudMixIn, ClickHouseSingleMixIn} +import org.apache.spark.sql.Row +import org.scalatest.tags.Cloud + +@Cloud +class ClickHouseCloudTableDDLSuite extends ClickHouseTableDDLSuite with ClickHouseCloudMixIn + +class ClickHouseSingleTableDDLSuite extends ClickHouseTableDDLSuite with ClickHouseSingleMixIn + +abstract class ClickHouseTableDDLSuite extends SparkClickHouseSingleTest { + + import testImplicits._ + + test("clickhouse command runner") { + withTable("default.abc") { + runClickHouseSQL("CREATE TABLE default.abc(a UInt8) ENGINE=Memory()") + checkAnswer( + spark.sql("""DESC default.abc""").select($"col_name", $"data_type").limit(1), + Row("a", "smallint") :: Nil + ) + } + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseWriterTestBase.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseWriterTestBase.scala new file mode 100644 index 00000000..28267dc2 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseWriterTestBase.scala @@ -0,0 +1,758 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse.single + +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.types._ + +/** + * Shared test cases for both JSON and Binary writers. + * Subclasses only need to configure the write format. 
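+ *
+ * For illustration only (the class name below is hypothetical; the config key is the one set in
+ * SparkClickHouseSingleTest), a concrete suite would pin the format roughly like:
+ * {{{
+ * class ClickHouseSingleJsonWriterSuite extends ClickHouseWriterTestBase with ClickHouseSingleMixIn {
+ *   override protected def sparkConf: SparkConf = super.sparkConf
+ *     .set("spark.clickhouse.write.format", "json")
+ * }
+ * }}}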
+ */ +trait ClickHouseWriterTestBase extends SparkClickHouseSingleTest { + + test("write ArrayType - array of integers") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", ArrayType(IntegerType, containsNull = false), nullable = false) + )) + + withTable("test_db", "test_write_array_int", schema) { + val data = Seq( + Row(1, Seq(1, 2, 3)), + Row(2, Seq(10, 20, 30)), + Row(3, Seq(100)) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_array_int") + + val result = spark.table("test_db.test_write_array_int").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getSeq[Int](1) == Seq(1, 2, 3)) + assert(result(1).getSeq[Int](1) == Seq(10, 20, 30)) + assert(result(2).getSeq[Int](1) == Seq(100)) + } + } + + test("write ArrayType - empty arrays") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", ArrayType(IntegerType, containsNull = false), nullable = false) + )) + + withTable("test_db", "test_write_empty_array", schema) { + val data = Seq( + Row(1, Seq()), + Row(2, Seq(1, 2, 3)), + Row(3, Seq()) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_empty_array") + + val result = spark.table("test_db.test_write_empty_array").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getSeq[Int](1).isEmpty) + assert(result(1).getSeq[Int](1) == Seq(1, 2, 3)) + assert(result(2).getSeq[Int](1).isEmpty) + } + } + + test("write ArrayType - nested arrays") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField( + "value", + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = false), + nullable = false + ) + )) + + withTable("test_db", "test_write_nested_array", schema) { + val data = Seq( + Row(1, Seq(Seq(1, 2), Seq(3, 4))), + Row(2, Seq(Seq(10, 20, 30))), + Row(3, Seq(Seq(), Seq(100))) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_nested_array") + + val result = spark.table("test_db.test_write_nested_array").orderBy("id").collect() + assert(result.length == 3) + // Convert to List for Scala 2.12/2.13 compatibility + val row0 = result(0).getAs[scala.collection.Seq[scala.collection.Seq[Int]]](1).map(_.toList).toList + val row1 = result(1).getAs[scala.collection.Seq[scala.collection.Seq[Int]]](1).map(_.toList).toList + val row2 = result(2).getAs[scala.collection.Seq[scala.collection.Seq[Int]]](1).map(_.toList).toList + assert(row0 == Seq(Seq(1, 2), Seq(3, 4))) + assert(row1 == Seq(Seq(10, 20, 30))) + assert(row2(0).isEmpty) + assert(row2(1) == Seq(100)) + } + } + + test("write ArrayType - with nullable elements") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", ArrayType(IntegerType, containsNull = true), nullable = false) + )) + + withTable("test_db", "test_write_array_nullable", schema) { + val data = Seq( + Row(1, Seq(1, null, 3)), + Row(2, Seq(null, null)), + Row(3, Seq(10, 20, 30)) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_array_nullable") + + val result = spark.table("test_db.test_write_array_nullable").orderBy("id").collect() + 
assert(result.length == 3) + val arr1 = result(0).getSeq[Any](1) + assert(arr1.length == 3) + assert(arr1(0) == 1) + assert(arr1(1) == null) + assert(arr1(2) == 3) + } + } + + test("write BooleanType - nullable with null values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", BooleanType, nullable = true) + )) + + withTable("test_db", "test_write_bool_null", schema) { + val data = Seq( + Row(1, true), + Row(2, null), + Row(3, false) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_bool_null") + + val result = spark.table("test_db.test_write_bool_null").orderBy("id").collect() + assert(result.length == 3) + // Boolean is stored as UInt8 in ClickHouse, reads back as Short + assert(result(0).getShort(1) == 1) + assert(result(1).isNullAt(1)) + assert(result(2).getShort(1) == 0) + } + } + + // NOTE: ClickHouse stores Boolean as UInt8, so it reads back as Short (0 or 1) + test("write BooleanType - true and false values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", BooleanType, nullable = false) + )) + + withTable("test_db", "test_write_bool", schema) { + val data = Seq( + Row(1, true), + Row(2, false) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_bool") + + val result = spark.table("test_db.test_write_bool").orderBy("id").collect() + assert(result.length == 2) + // Boolean is stored as UInt8 in ClickHouse, reads back as Short + assert(result(0).getShort(1) == 1) + assert(result(1).getShort(1) == 0) + } + } + + test("write ByteType - min and max values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", ByteType, nullable = false) + )) + + withTable("test_db", "test_write_byte", schema) { + val data = Seq( + Row(1, Byte.MinValue), + Row(2, 0.toByte), + Row(3, Byte.MaxValue) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_byte") + + val result = spark.table("test_db.test_write_byte").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getByte(1) == Byte.MinValue) + assert(result(1).getByte(1) == 0.toByte) + assert(result(2).getByte(1) == Byte.MaxValue) + } + } + + test("write ByteType - nullable with null values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", ByteType, nullable = true) + )) + + withTable("test_db", "test_write_byte_null", schema) { + val data = Seq( + Row(1, Byte.MinValue), + Row(2, null), + Row(3, Byte.MaxValue) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_byte_null") + + val result = spark.table("test_db.test_write_byte_null").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getByte(1) == Byte.MinValue) + assert(result(1).isNullAt(1)) + assert(result(2).getByte(1) == Byte.MaxValue) + } + } + + test("write DateType - dates") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", DateType, nullable = false) + )) + + withTable("test_db", "test_write_date", schema) { + val data = Seq( + Row(1, java.sql.Date.valueOf("2024-01-01")), + Row(2, 
java.sql.Date.valueOf("2024-06-15")), + Row(3, java.sql.Date.valueOf("2024-12-31")) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_date") + + val result = spark.table("test_db.test_write_date").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getDate(1) != null) + assert(result(1).getDate(1) != null) + assert(result(2).getDate(1) != null) + } + } + + test("write DateType - nullable with null values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", DateType, nullable = true) + )) + + withTable("test_db", "test_write_date_null", schema) { + val data = Seq( + Row(1, java.sql.Date.valueOf("2024-01-01")), + Row(2, null), + Row(3, java.sql.Date.valueOf("2024-12-31")) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_date_null") + + val result = spark.table("test_db.test_write_date_null").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getDate(1) != null) + assert(result(1).isNullAt(1)) + assert(result(2).getDate(1) != null) + } + } + + test("write DecimalType - Decimal(10,2)") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", DecimalType(10, 2), nullable = false) + )) + + withTable("test_db", "test_write_decimal", schema) { + val data = Seq( + Row(1, BigDecimal("12345.67")), + Row(2, BigDecimal("-9999.99")), + Row(3, BigDecimal("0.01")) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_decimal") + + val result = spark.table("test_db.test_write_decimal").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getDecimal(1) == BigDecimal("12345.67").underlying()) + assert(result(1).getDecimal(1) == BigDecimal("-9999.99").underlying()) + assert(result(2).getDecimal(1) == BigDecimal("0.01").underlying()) + } + } + + test("write DecimalType - Decimal(18,4)") { + // Note: High-precision decimals (>15-17 significant digits) may lose precision in JSON/Arrow formats. + // This appears to be related to the serialization/deserialization path, possibly due to intermediate + // double conversions in the format parsers. This test uses tolerance-based assertions to account + // for this observed behavior. Binary format (RowBinaryWithNamesAndTypes) preserves full precision. 
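+ // (For context: an IEEE-754 double carries only ~15-17 significant decimal digits, so an 18-digit
+ // value such as 12345678901234.5678 cannot survive an intermediate double conversion exactly;
+ // hence the tolerance-based assertions below instead of exact equality.)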
+ val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", DecimalType(18, 4), nullable = false) + )) + + withTable("test_db", "test_write_decimal_18_4", schema) { + val data = Seq( + Row(1, BigDecimal("12345678901234.5678")), + Row(2, BigDecimal("-9999999999999.9999")), + Row(3, BigDecimal("0.0001")) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_decimal_18_4") + + val result = spark.table("test_db.test_write_decimal_18_4").orderBy("id").collect() + assert(result.length == 3) + // Use tolerance for high-precision values (18 significant digits) + val tolerance = BigDecimal("0.001") + assert((BigDecimal(result(0).getDecimal(1)) - BigDecimal("12345678901234.5678")).abs < tolerance) + assert((BigDecimal(result(1).getDecimal(1)) - BigDecimal("-9999999999999.9999")).abs < tolerance) + // Small values should be exact + assert(result(2).getDecimal(1) == BigDecimal("0.0001").underlying()) + } + } + + test("write DecimalType - nullable with null values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", DecimalType(10, 2), nullable = true) + )) + + withTable("test_db", "test_write_decimal_null", schema) { + val data = Seq( + Row(1, BigDecimal("12345.67")), + Row(2, null), + Row(3, BigDecimal("-9999.99")) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_decimal_null") + + val result = spark.table("test_db.test_write_decimal_null").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getDecimal(1) == BigDecimal("12345.67").underlying()) + assert(result(1).isNullAt(1)) + assert(result(2).getDecimal(1) == BigDecimal("-9999.99").underlying()) + } + } + + test("write DoubleType - nullable with null values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", DoubleType, nullable = true) + )) + + withTable("test_db", "test_write_double_null", schema) { + val data = Seq( + Row(1, 3.14159), + Row(2, null), + Row(3, -2.71828) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_double_null") + + val result = spark.table("test_db.test_write_double_null").orderBy("id").collect() + assert(result.length == 3) + assert(math.abs(result(0).getDouble(1) - 3.14159) < 0.00001) + assert(result(1).isNullAt(1)) + assert(math.abs(result(2).getDouble(1) - -2.71828) < 0.00001) + } + } + + test("write DoubleType - regular values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", DoubleType, nullable = false) + )) + + withTable("test_db", "test_write_double", schema) { + val data = Seq( + Row(1, 3.14159), + Row(2, -2.71828), + Row(3, 0.0) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_double") + + val result = spark.table("test_db.test_write_double").orderBy("id").collect() + assert(result.length == 3) + assert(math.abs(result(0).getDouble(1) - 3.14159) < 0.00001) + assert(math.abs(result(1).getDouble(1) - -2.71828) < 0.00001) + assert(result(2).getDouble(1) == 0.0) + } + } + + test("write FloatType - nullable with null values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable 
= false), + StructField("value", FloatType, nullable = true) + )) + + withTable("test_db", "test_write_float_null", schema) { + val data = Seq( + Row(1, 3.14f), + Row(2, null), + Row(3, -2.718f) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_float_null") + + val result = spark.table("test_db.test_write_float_null").orderBy("id").collect() + assert(result.length == 3) + assert(math.abs(result(0).getFloat(1) - 3.14f) < 0.001f) + assert(result(1).isNullAt(1)) + assert(math.abs(result(2).getFloat(1) - -2.718f) < 0.001f) + } + } + + test("write FloatType - regular values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", FloatType, nullable = false) + )) + + withTable("test_db", "test_write_float", schema) { + val data = Seq( + Row(1, 3.14f), + Row(2, -2.718f), + Row(3, 0.0f) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_float") + + val result = spark.table("test_db.test_write_float").orderBy("id").collect() + assert(result.length == 3) + assert(math.abs(result(0).getFloat(1) - 3.14f) < 0.001f) + assert(math.abs(result(1).getFloat(1) - -2.718f) < 0.001f) + assert(result(2).getFloat(1) == 0.0f) + } + } + + test("write IntegerType - min and max values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", IntegerType, nullable = false) + )) + + withTable("test_db", "test_write_int", schema) { + val data = Seq( + Row(1, Int.MinValue), + Row(2, 0), + Row(3, Int.MaxValue) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_int") + + val result = spark.table("test_db.test_write_int").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getInt(1) == Int.MinValue) + assert(result(1).getInt(1) == 0) + assert(result(2).getInt(1) == Int.MaxValue) + } + } + + test("write IntegerType - nullable with null values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", IntegerType, nullable = true) + )) + + withTable("test_db", "test_write_int_null", schema) { + val data = Seq( + Row(1, Int.MinValue), + Row(2, null), + Row(3, Int.MaxValue) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_int_null") + + val result = spark.table("test_db.test_write_int_null").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getInt(1) == Int.MinValue) + assert(result(1).isNullAt(1)) + assert(result(2).getInt(1) == Int.MaxValue) + } + } + + test("write LongType - min and max values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", LongType, nullable = false) + )) + + withTable("test_db", "test_write_long", schema) { + val data = Seq( + Row(1, Long.MinValue), + Row(2, 0L), + Row(3, Long.MaxValue) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_long") + + val result = spark.table("test_db.test_write_long").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getLong(1) == Long.MinValue) + assert(result(1).getLong(1) == 0L) + assert(result(2).getLong(1) == 
Long.MaxValue) + } + } + + test("write LongType - nullable with null values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", LongType, nullable = true) + )) + + withTable("test_db", "test_write_long_null", schema) { + val data = Seq( + Row(1, Long.MinValue), + Row(2, null), + Row(3, Long.MaxValue) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_long_null") + + val result = spark.table("test_db.test_write_long_null").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getLong(1) == Long.MinValue) + assert(result(1).isNullAt(1)) + assert(result(2).getLong(1) == Long.MaxValue) + } + } + + test("write MapType - empty maps") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", MapType(StringType, IntegerType, valueContainsNull = false), nullable = false) + )) + + withTable("test_db", "test_write_empty_map", schema) { + val data = Seq( + Row(1, Map[String, Int]()), + Row(2, Map("a" -> 1)), + Row(3, Map[String, Int]()) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_empty_map") + + val result = spark.table("test_db.test_write_empty_map").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getMap[String, Int](1).isEmpty) + assert(result(1).getMap[String, Int](1) == Map("a" -> 1)) + assert(result(2).getMap[String, Int](1).isEmpty) + } + } + + test("write MapType - map of string to int") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", MapType(StringType, IntegerType, valueContainsNull = false), nullable = false) + )) + + withTable("test_db", "test_write_map", schema) { + val data = Seq( + Row(1, Map("a" -> 1, "b" -> 2)), + Row(2, Map("x" -> 10, "y" -> 20)), + Row(3, Map("foo" -> 100)) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_map") + + val result = spark.table("test_db.test_write_map").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getMap[String, Int](1) == Map("a" -> 1, "b" -> 2)) + assert(result(1).getMap[String, Int](1) == Map("x" -> 10, "y" -> 20)) + assert(result(2).getMap[String, Int](1) == Map("foo" -> 100)) + } + } + + test("write MapType - with nullable values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", MapType(StringType, IntegerType, valueContainsNull = true), nullable = false) + )) + + withTable("test_db", "test_write_map_nullable", schema) { + val data = Seq( + Row(1, Map("a" -> 1, "b" -> null)), + Row(2, Map("x" -> null, "y" -> 20)), + Row(3, Map("foo" -> 100)) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_map_nullable") + + val result = spark.table("test_db.test_write_map_nullable").orderBy("id").collect() + assert(result.length == 3) + val map1 = result(0).getMap[String, Any](1) + assert(map1("a") == 1) + assert(map1("b") == null) + } + } + + test("write ShortType - min and max values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", ShortType, nullable = false) + )) + + withTable("test_db", "test_write_short", schema) { + 
val data = Seq( + Row(1, Short.MinValue), + Row(2, 0.toShort), + Row(3, Short.MaxValue) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_short") + + val result = spark.table("test_db.test_write_short").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getShort(1) == Short.MinValue) + assert(result(1).getShort(1) == 0.toShort) + assert(result(2).getShort(1) == Short.MaxValue) + } + } + + test("write ShortType - nullable with null values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", ShortType, nullable = true) + )) + + withTable("test_db", "test_write_short_null", schema) { + val data = Seq( + Row(1, Short.MinValue), + Row(2, null), + Row(3, Short.MaxValue) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_short_null") + + val result = spark.table("test_db.test_write_short_null").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getShort(1) == Short.MinValue) + assert(result(1).isNullAt(1)) + assert(result(2).getShort(1) == Short.MaxValue) + } + } + + test("write StringType - empty strings") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", StringType, nullable = false) + )) + + withTable("test_db", "test_write_empty_string", schema) { + val data = Seq( + Row(1, ""), + Row(2, "not empty"), + Row(3, "") + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_empty_string") + + val result = spark.table("test_db.test_write_empty_string").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getString(1) == "") + assert(result(1).getString(1) == "not empty") + assert(result(2).getString(1) == "") + } + } + + test("write StringType - nullable with null values") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", StringType, nullable = true) + )) + + withTable("test_db", "test_write_string_null", schema) { + val data = Seq( + Row(1, "hello"), + Row(2, null), + Row(3, "world") + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_string_null") + + val result = spark.table("test_db.test_write_string_null").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getString(1) == "hello") + assert(result(1).isNullAt(1)) + assert(result(2).getString(1) == "world") + } + } + + test("write StringType - regular strings") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", StringType, nullable = false) + )) + + withTable("test_db", "test_write_string", schema) { + val data = Seq( + Row(1, "hello"), + Row(2, "world"), + Row(3, "test") + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_string") + + val result = spark.table("test_db.test_write_string").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getString(1) == "hello") + assert(result(1).getString(1) == "world") + assert(result(2).getString(1) == "test") + } + } + + test("write TimestampType - nullable with null values") { + val schema = StructType(Seq( + 
StructField("id", IntegerType, nullable = false), + StructField("value", TimestampType, nullable = true) + )) + + withTable("test_db", "test_write_timestamp_null", schema) { + val data = Seq( + Row(1, java.sql.Timestamp.valueOf("2024-01-01 12:00:00")), + Row(2, null), + Row(3, java.sql.Timestamp.valueOf("2024-12-31 23:59:59")) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_timestamp_null") + + val result = spark.table("test_db.test_write_timestamp_null").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getTimestamp(1) != null) + assert(result(1).isNullAt(1)) + assert(result(2).getTimestamp(1) != null) + } + } + + test("write TimestampType - timestamps") { + val schema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", TimestampType, nullable = false) + )) + + withTable("test_db", "test_write_timestamp", schema) { + val data = Seq( + Row(1, java.sql.Timestamp.valueOf("2024-01-01 12:00:00")), + Row(2, java.sql.Timestamp.valueOf("2024-06-15 18:30:45")), + Row(3, java.sql.Timestamp.valueOf("2024-12-31 23:59:59")) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.mode(SaveMode.Append).saveAsTable("test_db.test_write_timestamp") + + val result = spark.table("test_db.test_write_timestamp").orderBy("id").collect() + assert(result.length == 3) + assert(result(0).getTimestamp(1) != null) + assert(result(1).getTimestamp(1) != null) + assert(result(2).getTimestamp(1) != null) + } + } + +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/SparkClickHouseSingleTest.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/SparkClickHouseSingleTest.scala new file mode 100644 index 00000000..d9e7890a --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/SparkClickHouseSingleTest.scala @@ -0,0 +1,156 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse.single + +import com.clickhouse.spark.base.ClickHouseProvider +import org.apache.spark.SparkConf +import org.apache.spark.sql.clickhouse.SparkTest +import org.apache.spark.sql.functions.month +import org.apache.spark.sql.types.StructType + +trait SparkClickHouseSingleTest extends SparkTest with ClickHouseProvider { + + import testImplicits._ + + override protected def sparkConf: SparkConf = super.sparkConf + .setMaster("local[2]") + .setAppName("spark-clickhouse-single-ut") + .set("spark.sql.shuffle.partitions", "2") + // catalog + .set("spark.sql.defaultCatalog", "clickhouse") + .set("spark.sql.catalog.clickhouse", "com.clickhouse.spark.ClickHouseCatalog") + .set("spark.sql.catalog.clickhouse.host", clickhouseHost) + .set("spark.sql.catalog.clickhouse.http_port", clickhouseHttpPort.toString) + .set("spark.sql.catalog.clickhouse.protocol", "http") + .set("spark.sql.catalog.clickhouse.user", clickhouseUser) + .set("spark.sql.catalog.clickhouse.password", clickhousePassword) + .set("spark.sql.catalog.clickhouse.database", clickhouseDatabase) + .set("spark.sql.catalog.clickhouse.option.custom_http_params", "async_insert=1,wait_for_async_insert=1") + .set("spark.sql.catalog.clickhouse.option.ssl", isSslEnabled.toString) + // extended configurations + .set("spark.clickhouse.write.batchSize", "2") + .set("spark.clickhouse.write.maxRetry", "2") + .set("spark.clickhouse.write.retryInterval", "1") + .set("spark.clickhouse.write.retryableErrorCodes", "241") + .set("spark.clickhouse.write.write.repartitionNum", "0") + .set("spark.clickhouse.read.format", "json") + .set("spark.clickhouse.write.format", "json") + + override def cmdRunnerOptions: Map[String, String] = Map( + "host" -> clickhouseHost, + "http_port" -> clickhouseHttpPort.toString, + "protocol" -> "http", + "user" -> clickhouseUser, + "password" -> clickhousePassword, + "database" -> clickhouseDatabase, + "option.custom_http_params" -> "async_insert=1,wait_for_async_insert=1", + "option.ssl" -> isSslEnabled.toString + ) + + def withTable( + db: String, + tbl: String, + schema: StructType, + engine: String = "MergeTree()", + sortKeys: Seq[String] = "id" :: Nil, + partKeys: Seq[String] = Seq.empty + )(f: => Unit): Unit = + try { + runClickHouseSQL(s"CREATE DATABASE IF NOT EXISTS $db") + + spark.sql( + s"""CREATE TABLE $db.$tbl ( + | ${schema.fields.map(_.toDDL).mkString(",\n ")} + |) USING ClickHouse + |${if (partKeys.isEmpty) "" else partKeys.mkString("PARTITIONED BY(", ", ", ")")} + |TBLPROPERTIES ( + | ${if (sortKeys.isEmpty) "" else sortKeys.mkString("order_by = '", ", ", "',")} + | engine = '$engine' + |) + |""".stripMargin + ) + + f + } finally { + runClickHouseSQL(s"DROP TABLE IF EXISTS $db.$tbl") + runClickHouseSQL(s"DROP DATABASE IF EXISTS $db") + } + + def withKVTable( + db: String, + tbl: String, + keyColDef: String = "Int32", + valueColDef: String + )(f: => Unit): Unit = + try { + runClickHouseSQL(s"CREATE DATABASE IF NOT EXISTS $db") + runClickHouseSQL( + s"""CREATE TABLE $db.$tbl ( + | key $keyColDef, + | value $valueColDef + |) ENGINE = MergeTree() + |ORDER BY key + |""".stripMargin + ) + f + } finally { + runClickHouseSQL(s"DROP TABLE IF EXISTS $db.$tbl") + runClickHouseSQL(s"DROP DATABASE IF EXISTS $db") + } + + def withSimpleTable( + db: String, + tbl: String, + writeData: Boolean = false + )(f: => Unit): Unit = + try { + runClickHouseSQL(s"CREATE DATABASE IF NOT EXISTS $db") + + // SPARK-33779: Spark 3.3 only support IdentityTransform + spark.sql( + s"""CREATE TABLE $db.$tbl 
( + | id BIGINT NOT NULL COMMENT 'sort key', + | value STRING, + | create_time TIMESTAMP NOT NULL, + | m INT NOT NULL COMMENT 'part key' + |) USING ClickHouse + |PARTITIONED BY (m) + |TBLPROPERTIES ( + | engine = 'MergeTree()', + | order_by = 'id' + |) + |""".stripMargin + ) + + if (writeData) { + val tblSchema = spark.table(s"$db.$tbl").schema + val dataDF = spark.createDataFrame(Seq( + (1L, "1", timestamp("2021-01-01T10:10:10Z")), + (2L, "2", timestamp("2022-02-02T10:10:10Z")) + )).toDF("id", "value", "create_time") + .withColumn("m", month($"create_time")) + .select($"id", $"value", $"create_time", $"m") + + spark.createDataFrame(dataDF.rdd, tblSchema) + .writeTo(s"$db.$tbl") + .append + } + + f + } finally { + runClickHouseSQL(s"DROP TABLE IF EXISTS $db.$tbl") + runClickHouseSQL(s"DROP DATABASE IF EXISTS $db") + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/TPCDSSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/TPCDSSuite.scala new file mode 100644 index 00000000..cad773c1 --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/TPCDSSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse.single + +import com.clickhouse.spark.base.{ClickHouseCloudMixIn, ClickHouseSingleMixIn} +import org.apache.spark.SparkConf +import org.apache.spark.sql.clickhouse.TPCDSTestUtils +import org.scalatest.tags.{Cloud, Slow} + +@Cloud +class ClickHouseCloudTPCDSSuite extends TPCDSSuite with ClickHouseCloudMixIn + +@Slow +class ClickHouseSingleTPCDSSuite extends TPCDSSuite with ClickHouseSingleMixIn + +abstract class TPCDSSuite extends SparkClickHouseSingleTest { + + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.catalog.tpcds", "org.apache.kyuubi.spark.connector.tpcds.TPCDSCatalog") + .set("spark.sql.catalog.clickhouse.protocol", "http") + .set("spark.clickhouse.read.compression.codec", "none") + .set("spark.clickhouse.write.batchSize", "100000") + .set("spark.clickhouse.write.compression.codec", "none") + + test("TPC-DS tiny write and count(*)") { + withDatabase("tpcds_tiny") { + spark.sql("CREATE DATABASE tpcds_tiny") + + TPCDSTestUtils.tablePrimaryKeys.foreach { case (table, primaryKeys) => + spark.sql( + s""" + |CREATE TABLE tpcds_tiny.$table + |USING clickhouse + |TBLPROPERTIES ( + | order_by = '${primaryKeys.mkString(",")}', + | 'settings.allow_nullable_key' = 1 + |) + |SELECT * FROM tpcds.tiny.$table; + |""".stripMargin + ) + } + + TPCDSTestUtils.tablePrimaryKeys.keys.foreach { table => + assert(spark.table(s"tpcds_tiny.$table").count === spark.table(s"tpcds.tiny.$table").count) + } + } + } +} diff --git a/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/WriteDistributionAndOrderingSuite.scala b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/WriteDistributionAndOrderingSuite.scala new file mode 100644 index 00000000..6469c07d --- /dev/null +++ b/spark-4.0/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/WriteDistributionAndOrderingSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse.single + +import com.clickhouse.spark.base.{ClickHouseCloudMixIn, ClickHouseSingleMixIn} +import org.apache.spark.sql.clickhouse.ClickHouseSQLConf._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.Row +import org.scalatest.tags.Cloud + +@Cloud +class ClickHouseCloudsWriteDistributionAndOrderingSuite + extends WriteDistributionAndOrderingSuite with ClickHouseCloudMixIn + +class ClickHouseSinglesWriteDistributionAndOrderingSuite + extends WriteDistributionAndOrderingSuite with ClickHouseSingleMixIn + +abstract class WriteDistributionAndOrderingSuite extends SparkClickHouseSingleTest { + + import testImplicits._ + + private val db = "db_distribution_and_ordering" + private val tbl = "tbl_distribution_and_ordering" + + private def write(): Unit = spark.range(3) + .toDF("id") + .withColumn("id", $"id".cast(StringType)) + .withColumn("load_date", lit(date("2022-05-27"))) + .writeTo(s"$db.$tbl") + .append + + private def check(): Unit = checkAnswer( + spark.sql(s"SELECT id, load_date FROM $db.$tbl"), + Seq( + Row("0", date("2022-05-27")), + Row("1", date("2022-05-27")), + Row("2", date("2022-05-27")) + ) + ) + + override protected def beforeAll(): Unit = { + super.beforeAll() + sql(s"CREATE DATABASE IF NOT EXISTS `$db`") + runClickHouseSQL( + s"""CREATE TABLE `$db`.`$tbl` ( + | `id` String, + | `load_date` Date + |) ENGINE = MergeTree + |ORDER BY load_date + |PARTITION BY xxHash64(id) + |""".stripMargin + ) + } + + override protected def afterAll(): Unit = { + sql(s"DROP TABLE IF EXISTS `$db`.`$tbl`") + sql(s"DROP DATABASE IF EXISTS `$db`") + super.afterAll() + } + + override protected def beforeEach(): Unit = { + sql(s"TRUNCATE TABLE `$db`.`$tbl`") + super.beforeEach() + } + + def writeDataToTablesContainsUnsupportedPartitions( + ignoreUnsupportedTransform: Boolean, + repartitionByPartition: Boolean, + localSortByKey: Boolean + ): Unit = withSQLConf( + IGNORE_UNSUPPORTED_TRANSFORM.key -> ignoreUnsupportedTransform.toString, + WRITE_REPARTITION_BY_PARTITION.key -> repartitionByPartition.toString, + WRITE_LOCAL_SORT_BY_KEY.key -> localSortByKey.toString + ) { + write() + check() + } + + Seq(true, false).foreach { ignoreUnsupportedTransform => + Seq(true, false).foreach { repartitionByPartition => + Seq(true, false).foreach { localSortByKey => + test("write data to table contains unsupported partitions - " + + s"ignoreUnsupportedTransform=$ignoreUnsupportedTransform " + + s"repartitionByPartition=$repartitionByPartition " + + s"localSortByKey=$localSortByKey") { + writeDataToTablesContainsUnsupportedPartitions( + ignoreUnsupportedTransform, + repartitionByPartition, + localSortByKey + ) + } + } + } + } +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/ClickHouseCatalog.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/ClickHouseCatalog.scala new file mode 100644 index 00000000..06f0e1ff --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/ClickHouseCatalog.scala @@ -0,0 +1,388 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark + +import com.clickhouse.client.ClickHouseProtocol +import com.clickhouse.spark.exception.CHClientException +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.clickhouse.{ExprUtils, SchemaUtils} +import org.apache.spark.sql.connector.catalog._ +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import Constants._ +import com.clickhouse.spark.exception.ClickHouseErrCode._ +import com.clickhouse.spark.func.{ + ClickHouseXxHash64Shard, + CompositeFunctionRegistry, + DynamicFunctionRegistry, + FunctionRegistry, + StaticFunctionRegistry +} +import com.clickhouse.spark.spec.{ClusterSpec, DistributedEngineSpec, NodeSpec, TableEngineUtils} +import com.clickhouse.spark.func._ +import com.clickhouse.spark.client.NodeClient +import com.clickhouse.spark.spec._ + +import java.time.ZoneId +import java.util +import scala.collection.JavaConverters._ + +class ClickHouseCatalog extends TableCatalog + with SupportsNamespaces + with FunctionCatalog + with ClickHouseHelper + with SQLHelper + with Logging { + + private var catalogName: String = _ + + // /////////////////////////////////////////////////// + // ////////////////// SINGLE NODE //////////////////// + // /////////////////////////////////////////////////// + private var nodeSpec: NodeSpec = _ + + implicit private var nodeClient: NodeClient = _ + + // case Left => server timezone + // case Right => client timezone or user specific timezone + private var tz: Either[ZoneId, ZoneId] = _ + + private var currentDb: String = _ + + // /////////////////////////////////////////////////// + // /////////////////// CLUSTERS ////////////////////// + // /////////////////////////////////////////////////// + private var clusterSpecs: Seq[ClusterSpec] = Nil + + private var functionRegistry: FunctionRegistry = _ + + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = { + this.catalogName = name + this.nodeSpec = buildNodeSpec(options) + this.currentDb = nodeSpec.database + this.nodeClient = NodeClient(nodeSpec) + + this.nodeClient.syncQueryAndCheckOutputJSONEachRow("SELECT 1") + + this.tz = options.get(CATALOG_PROP_TZ) match { + case tz if tz == null || tz.isEmpty || tz.toLowerCase == "server" => + val timezoneOutput = this.nodeClient.syncQueryAndCheckOutputJSONEachRow("SELECT timezone() AS tz") + assert(timezoneOutput.rows == 1) + val serverTz = ZoneId.of(timezoneOutput.records.head.get("tz").asText) + log.info(s"Detect ClickHouse server timezone: $serverTz") + Left(serverTz) + case tz if tz.toLowerCase == "client" => Right(ZoneId.systemDefault) + case tz => Right(ZoneId.of(tz)) + } + + this.clusterSpecs = queryClusterSpecs(nodeSpec) + + val dynamicFunctionRegistry = new DynamicFunctionRegistry + val xxHash64ShardFunc = new ClickHouseXxHash64Shard(clusterSpecs) + dynamicFunctionRegistry.register("ck_xx_hash64_shard", xxHash64ShardFunc) // for compatible + 
dynamicFunctionRegistry.register("clickhouse_shard_xxHash64", xxHash64ShardFunc) + this.functionRegistry = new CompositeFunctionRegistry(Array(StaticFunctionRegistry, dynamicFunctionRegistry)) + + log.info(s"Detect ${clusterSpecs.size} ClickHouse clusters: ${clusterSpecs.map(_.name).mkString(",")}") + log.info(s"ClickHouse clusters' detail: $clusterSpecs") + log.info(s"Registered functions: ${this.functionRegistry.list.mkString(",")}") + } + + override def name(): String = catalogName + + @throws[NoSuchNamespaceException] + override def listTables(namespace: Array[String]): Array[Identifier] = namespace match { + case Array(database) => + nodeClient.syncQueryOutputJSONEachRow(s"SHOW TABLES IN ${quoted(database)}") match { + case Left(exception) if exception.code == UNKNOWN_DATABASE.code => + throw new NoSuchNamespaceException(namespace) + case Left(rethrow) => + throw rethrow + case Right(output) => + output.records + .map(row => row.get("name").asText) + .map(table => Identifier.of(namespace, table)) + .toArray + } + case _ => throw new NoSuchNamespaceException(namespace) + } + + @throws[NoSuchTableException] + override def loadTable(ident: Identifier): ClickHouseTable = { + val (database, table) = unwrap(ident) match { + case None => throw new NoSuchTableException(ident) + case Some((db, tbl)) => + nodeClient.syncQueryOutputJSONEachRow(s"SELECT * FROM `$db`.`$tbl` WHERE 1=0") match { + case Left(exception) if exception.code == UNKNOWN_TABLE.code => + throw new NoSuchTableException(ident) + // not sure if this check is necessary + case Left(exception) if exception.code == UNKNOWN_DATABASE.code => + throw new NoSuchTableException(Array(db)) + case Left(rethrow) => + throw rethrow + case Right(_) => (db, tbl) + } + } + implicit val _tz: ZoneId = tz.merge + val tableSpec = queryTableSpec(database, table) + val tableEngineSpec = TableEngineUtils.resolveTableEngine(tableSpec) + val tableClusterSpec = tableEngineSpec match { + case distributeSpec: DistributedEngineSpec => + Some(TableEngineUtils.resolveTableCluster(distributeSpec, clusterSpecs)) + case _ => None + } + ClickHouseTable( + nodeSpec, + tableClusterSpec, + _tz, + tableSpec, + tableEngineSpec, + functionRegistry + ) + } + + /** + *

<h2>MergeTree Engine</h2>
+ * {{{ + * CREATE TABLE [IF NOT EXISTS] [db.]table_name [ON CLUSTER cluster] + * ( + * name1 [type1] [DEFAULT|MATERIALIZED|ALIAS expr1] [TTL expr1], + * name2 [type2] [DEFAULT|MATERIALIZED|ALIAS expr2] [TTL expr2], + * ... + * INDEX index_name1 expr1 TYPE type1(...) GRANULARITY value1, + * INDEX index_name2 expr2 TYPE type2(...) GRANULARITY value2 + * ) ENGINE = MergeTree() + * ORDER BY expr + * [PARTITION BY expr] + * [PRIMARY KEY expr] + * [SAMPLE BY expr] + * [TTL expr + * [DELETE|TO DISK 'xxx'|TO VOLUME 'xxx' [, ...] ] + * [WHERE conditions] + * [GROUP BY key_expr [SET v1 = agg_func(v1) [, v2 = agg_func(v2) ...]] ]] + * [SETTINGS name=value, ...] + * }}} + *

<p> + * + * <h2>ReplacingMergeTree Engine</h2>
+ * {{{ + * CREATE TABLE [IF NOT EXISTS] [db.]table_name [ON CLUSTER cluster] + * ( + * name1 [type1] [DEFAULT|MATERIALIZED|ALIAS expr1], + * name2 [type2] [DEFAULT|MATERIALIZED|ALIAS expr2], + * ... + * ) ENGINE = ReplacingMergeTree([ver]) + * [PARTITION BY expr] + * [ORDER BY expr] + * [PRIMARY KEY expr] + * [SAMPLE BY expr] + * [SETTINGS name=value, ...] + * }}} + * + * `ver` — column with version. Type `UInt*`, `Date` or `DateTime`. + */ + @throws[TableAlreadyExistsException] + @throws[NoSuchNamespaceException] + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String] + ): ClickHouseTable = { + val (db, tbl) = unwrap(ident) match { + case Some((d, t)) => (d, t) + case None => throw CHClientException(s"Invalid table identifier: $ident") + } + val props = properties.asScala + + val engineExpr = props.getOrElse("engine", "MergeTree()") + + val isCreatingDistributed = engineExpr equalsIgnoreCase "Distributed" + val keyPrefix = if (isCreatingDistributed) "local." else "" + + val partitionsClause = partitions match { + case transforms if transforms.nonEmpty => + transforms.map(ExprUtils.toClickHouse(_, functionRegistry).sql).mkString("PARTITION BY (", ", ", ")") + case _ => "" + } + + val orderClause = props.get(s"${keyPrefix}order_by").map(o => s"ORDER BY ($o)").getOrElse("") + val primaryKeyClause = props.get(s"${keyPrefix}primary_key").map(p => s"PRIMARY KEY ($p)").getOrElse("") + val sampleClause = props.get(s"${keyPrefix}sample_by").map(p => s"SAMPLE BY ($p)").getOrElse("") + + val fieldsClause = SchemaUtils + .toClickHouseSchema(schema) + .map { case (fieldName, ckType, comment) => s"${quoted(fieldName)} $ckType$comment" } + .mkString(",\n ") + + val clusterOpt = props.get("cluster") + + def tblSettingsClause(prefix: String): String = props.filterKeys(_.startsWith(prefix)) match { + case settings if settings.nonEmpty => + settings.map { case (k, v) => + s"${k.substring(prefix.length)}=$v" + }.mkString("SETTINGS ", ", ", "") + case _ => "" + } + + def createTable( + clusterOpt: Option[String], + engineExpr: String, + database: String, + table: String, + settingsClause: String + ): Unit = { + val clusterClause = clusterOpt.map(c => s"ON CLUSTER $c").getOrElse("") + nodeClient.syncQueryAndCheckOutputJSONEachRow( + s"""CREATE TABLE `$database`.`$table` $clusterClause ( + |$fieldsClause + |) ENGINE = $engineExpr + |$partitionsClause + |$orderClause + |$primaryKeyClause + |$sampleClause + |$settingsClause + |""".stripMargin + .replaceAll("""\n\s+\n""", "\n") // remove empty lines + ) + } + + def createDistributedTable( + cluster: String, + shardExpr: String, + localDatabase: String, + localTable: String, + distributedDatabase: String, + distributedTable: String, + settingsClause: String + ): Unit = nodeClient.syncQueryAndCheckOutputJSONEachRow( + s"""CREATE TABLE `$distributedDatabase`.`$distributedTable` ON CLUSTER $cluster + |AS `$localDatabase`.`$localTable` + |ENGINE = Distributed($cluster, '$localDatabase', '$localTable', ($shardExpr)) + |$settingsClause + |""".stripMargin + ) + + if (isCreatingDistributed) { + val cluster = clusterOpt.getOrElse("default") + val shardExpr = props.getOrElse("shard_by", "rand()") + val settingsClause = tblSettingsClause("settings.") + val localEngineExpr = props.getOrElse(s"${keyPrefix}engine", s"MergeTree()") + val localDatabase = props.getOrElse(s"${keyPrefix}database", db) + val localTable = props.getOrElse(s"${keyPrefix}table", s"${tbl}_local") + val 
localSettingsClause = tblSettingsClause(s"${keyPrefix}settings.") + createTable(Some(cluster), localEngineExpr, localDatabase, localTable, localSettingsClause) + createDistributedTable(cluster, shardExpr, localDatabase, localTable, db, tbl, settingsClause) + } else { + val settingsClause = tblSettingsClause(s"${keyPrefix}settings.") + createTable(clusterOpt, engineExpr, db, tbl, settingsClause) + } + + loadTable(ident) + } + + @throws[NoSuchTableException] + override def alterTable(ident: Identifier, changes: TableChange*): ClickHouseTable = + throw new UnsupportedOperationException + + override def dropTable(ident: Identifier): Boolean = { + val tableOpt = + try Some(loadTable(ident)) + catch { + case _: NoSuchTableException => None + } + tableOpt match { + case None => false + case Some(ClickHouseTable(_, cluster, _, tableSpec, _, _)) => + val (db, tbl) = (tableSpec.database, tableSpec.name) + val isAtomic = loadNamespaceMetadata(Array(db)).get("engine").equalsIgnoreCase("atomic") + val syncClause = if (isAtomic) "SYNC" else "" + // limitation: only support Distribute table, can not handle cases such as drop local table on cluster nodes + val clusterClause = cluster.map(c => s"ON CLUSTER ${c.name}").getOrElse("") + nodeClient.syncQueryOutputJSONEachRow(s"DROP TABLE `$db`.`$tbl` $clusterClause $syncClause").isRight + } + } + + @throws[NoSuchTableException] + @throws[TableAlreadyExistsException] + override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = + (unwrap(oldIdent), unwrap(newIdent)) match { + case (Some((oldDb, oldTbl)), Some((newDb, newTbl))) => + nodeClient.syncQueryOutputJSONEachRow(s"RENAME TABLE `$oldDb`.`$oldTbl` to `$newDb`.`$newTbl`") match { + case Left(exception) => throw new NoSuchTableException( + errorClass = "TABLE_OR_VIEW_NOT_FOUND", + messageParameters = Map("relationName" -> oldIdent.toString), + cause = Some(exception) + ) + case Right(_) => + } + case _ => throw CHClientException("Invalid table identifier") + } + + override def defaultNamespace(): Array[String] = Array(currentDb) + + @throws[NoSuchNamespaceException] + override def listNamespaces(): Array[Array[String]] = { + val output = nodeClient.syncQueryAndCheckOutputJSONEachRow("SHOW DATABASES") + output.records.map(row => Array(row.get("name").asText)).toArray + } + + @throws[NoSuchNamespaceException] + override def listNamespaces(namespace: Array[String]): Array[Array[String]] = namespace match { + case Array() => listNamespaces() + case Array(_) => + loadNamespaceMetadata(namespace) + Array() + case _ => throw new NoSuchNamespaceException(namespace) + } + + @throws[NoSuchNamespaceException] + override def loadNamespaceMetadata(namespace: Array[String]): util.Map[String, String] = namespace match { + case Array(database) => queryDatabaseSpec(database).toJavaMap + case _ => throw new NoSuchNamespaceException(namespace) + } + + @throws[NamespaceAlreadyExistsException] + override def createNamespace(namespace: Array[String], metadata: util.Map[String, String]): Unit = namespace match { + case Array(database) => + val onClusterClause = metadata.asScala.get("cluster").map(c => s"ON CLUSTER $c").getOrElse("") + nodeClient.syncQueryOutputJSONEachRow(s"CREATE DATABASE ${quoted(database)} $onClusterClause") + } + + @throws[NoSuchNamespaceException] + override def alterNamespace(namespace: Array[String], changes: NamespaceChange*): Unit = + throw new UnsupportedOperationException("ALTER NAMESPACE OPERATION is unsupported yet") + + @throws[NoSuchNamespaceException] + override def 
dropNamespace(namespace: Array[String], cascade: Boolean): Boolean = namespace match { + case Array(database) => + loadNamespaceMetadata(namespace) // test existing + if (!cascade && listNamespaces(namespace).nonEmpty) { + throw new NonEmptyNamespaceException(namespace) + } + nodeClient.syncQueryOutputJSONEachRow(s"DROP DATABASE ${quoted(database)}").isRight + case _ => false + } + + @throws[NoSuchNamespaceException] + override def listFunctions(namespace: Array[String]): Array[Identifier] = + functionRegistry.list.map(name => Identifier.of(Array.empty, name)) + + @throws[NoSuchFunctionException] + override def loadFunction(ident: Identifier): UnboundFunction = + functionRegistry.load(ident.name).getOrElse(throw new NoSuchFunctionException(ident)) +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/ClickHouseCommandRunner.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/ClickHouseCommandRunner.scala new file mode 100644 index 00000000..ff486351 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/ClickHouseCommandRunner.scala @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark + +import com.clickhouse.spark.client.NodeClient +import org.apache.spark.sql.connector.ExternalCommandRunner +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class ClickHouseCommandRunner extends ExternalCommandRunner with ClickHouseHelper { + + override def executeCommand(sql: String, options: CaseInsensitiveStringMap): Array[String] = + Utils.tryWithResource(client.NodeClient(buildNodeSpec(options))) { nodeClient => + nodeClient.syncQueryAndCheckOutputJSONEachRow(sql).records.map(_.toString).toArray + } +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/ClickHouseHelper.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/ClickHouseHelper.scala new file mode 100644 index 00000000..a96d0505 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/ClickHouseHelper.scala @@ -0,0 +1,359 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.clickhouse.spark + +import com.clickhouse.client.ClickHouseProtocol +import com.clickhouse.spark.exception.CHException +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.node.NullNode +import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException} +import org.apache.spark.sql.clickhouse.SchemaUtils +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import Constants._ +import com.clickhouse.spark.client.NodeClient +import Utils.dateTimeFmt +import com.clickhouse.spark.spec.{ + ClusterSpec, + DatabaseSpec, + NoPartitionSpec, + NodeSpec, + PartitionSpec, + ReplicaSpec, + ShardSpec, + TableSpec +} +import com.clickhouse.spark.spec._ + +import java.time.{LocalDateTime, ZoneId} +import java.util.{HashMap => JHashMap} +import scala.collection.JavaConverters._ + +trait ClickHouseHelper extends Logging { + + @volatile lazy val DEFAULT_ACTION_IF_NO_SUCH_DATABASE: String => Unit = + (db: String) => throw new NoSuchNamespaceException(Array(db)) + + @volatile lazy val DEFAULT_ACTION_IF_NO_SUCH_TABLE: (String, String) => Unit = + (database, table) => throw new NoSuchTableException(database, table) + + def unwrap(ident: Identifier): Option[(String, String)] = ident.namespace() match { + case Array(database) => Some((database, ident.name())) + case _ => None + } + + def buildNodeSpec(options: CaseInsensitiveStringMap): NodeSpec = { + val clientOpts = options.asScala + .filterKeys(_.startsWith(CATALOG_PROP_OPTION_PREFIX)) + .map { case (k, v) => k.substring(CATALOG_PROP_OPTION_PREFIX.length) -> v } + .toMap + .filterKeys { key => + val ignore = CATALOG_PROP_IGNORE_OPTIONS.contains(key) + if (ignore) { + log.warn(s"Ignore configuration $key.") + } + !ignore + } + .toMap + NodeSpec( + _host = options.getOrDefault(CATALOG_PROP_HOST, "localhost"), + _tcp_port = Some(options.getInt(CATALOG_PROP_TCP_PORT, 9000)), + _http_port = Some(options.getInt(CATALOG_PROP_HTTP_PORT, 8123)), + protocol = ClickHouseProtocol.fromUriScheme(options.getOrDefault(CATALOG_PROP_PROTOCOL, "http")), + username = options.getOrDefault(CATALOG_PROP_USER, "default"), + password = options.getOrDefault(CATALOG_PROP_PASSWORD, ""), + database = options.getOrDefault(CATALOG_PROP_DATABASE, "default"), + infer_runtime_env = options.getOrDefault(CATALOG_INFER_RUNTIME_ENV, "true"), + options = new JHashMap(clientOpts.asJava) + ) + } + + def queryClusterSpecs(nodeSpec: NodeSpec)(implicit nodeClient: NodeClient): Seq[ClusterSpec] = { + val clustersOutput = nodeClient.syncQueryAndCheckOutputJSONEachRow( + """ SELECT + | `cluster`, -- String + | `shard_num`, -- UInt32 + | `shard_weight`, -- UInt32 + | `replica_num`, -- UInt32 + | `host_name`, -- String + | `host_address`, -- String + | `port`, -- UInt16 + | `is_local`, -- UInt8 + | `user`, -- String + | `default_database`, -- String + | `errors_count`, -- UInt32 + | `estimated_recovery_time` -- UInt32 + | FROM `system`.`clusters` + |""".stripMargin + ) + clustersOutput.records + .groupBy(_.get("cluster").asText) + .map { case (cluster, rows) => + val shards = rows + .groupBy(_.get("shard_num").asInt) + .map { case (shardNum, rows) => + val shardWeight = rows.head.get("shard_weight").asInt + val nodes = rows.map { row => + val replicaNum = row.get("replica_num").asInt + // should other properties be provided by `SparkConf`? 
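+ // Each replica row of `system`.`clusters` is mapped onto a copy of the catalog's NodeSpec, so
+ // credentials, protocol and other client options are inherited from the catalog while the host
+ // and ports come from the cluster metadata queried above.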
+ val clickhouseNode = nodeSpec.copy( + // host_address is not works for testcontainers + _host = row.get("host_name").asText, + _tcp_port = Some(row.get("port").asInt), + _http_port = if (Utils.isTesting) Some(8123) else nodeSpec.http_port + ) + ReplicaSpec(replicaNum, clickhouseNode) + }.toArray + ShardSpec(shardNum, shardWeight, nodes) + }.toArray + ClusterSpec(cluster, shards) + }.toSeq + } + + def queryDatabaseSpec( + database: String, + actionIfNoSuchDatabase: String => Unit = DEFAULT_ACTION_IF_NO_SUCH_DATABASE + )(implicit nodeClient: NodeClient): DatabaseSpec = { + val output = nodeClient.syncQueryAndCheckOutputJSONEachRow( + s"""SELECT + | `name`, -- String + | `engine`, -- String + | `data_path`, -- String + | `metadata_path`, -- String + | `uuid` -- String + |FROM `system`.`databases` + |WHERE `name`='$database' + |""".stripMargin + ) + if (output.rows == 0) { + actionIfNoSuchDatabase(database) + } + val row = output.records.head + DatabaseSpec( + name = row.get("name").asText, + engine = row.get("engine").asText, + data_path = row.get("data_path").asText, + metadata_path = row.get("metadata_path").asText, + uuid = row.get("uuid").asText + ) + } + + def queryTableSpec( + database: String, + table: String, + actionIfNoSuchTable: (String, String) => Unit = DEFAULT_ACTION_IF_NO_SUCH_TABLE + )(implicit + nodeClient: NodeClient, + tz: ZoneId + ): TableSpec = { + val tableOutput = nodeClient.syncQueryAndCheckOutputJSONEachRow( + s"""SELECT + | `database`, -- String + | `name`, -- String + | `uuid`, -- UUID + | `engine`, -- String + | `is_temporary`, -- UInt8 + | `data_paths`, -- Array(String) + | `metadata_path`, -- String + | `metadata_modification_time`, -- DateTime + | `dependencies_database`, -- Array(String) + | `dependencies_table`, -- Array(String) + | `create_table_query`, -- String + | `engine_full`, -- String + | `partition_key`, -- String + | `sorting_key`, -- String + | `primary_key`, -- String + | `sampling_key`, -- String + | `storage_policy`, -- String + | `total_rows`, -- Nullable(UInt64) + | `total_bytes`, -- Nullable(UInt64) + | `lifetime_rows`, -- Nullable(UInt64) + | `lifetime_bytes` -- Nullable(UInt64) + |FROM `system`.`tables` + |WHERE `database`='$database' AND `name`='$table' + |""".stripMargin + ) + if (tableOutput.isEmpty) { + actionIfNoSuchTable(database, table) + } + val tableRow = tableOutput.records.head + TableSpec( + database = tableRow.get("database").asText, + name = tableRow.get("name").asText, + uuid = tableRow.get("uuid").asText, + engine = tableRow.get("engine").asText, + is_temporary = tableRow.get("is_temporary").asBoolean, + data_paths = tableRow.get("data_paths").elements().asScala.map(_.asText).toList, + metadata_path = tableRow.get("metadata_path").asText, + metadata_modification_time = LocalDateTime.parse( + tableRow.get("metadata_modification_time").asText, + dateTimeFmt.withZone(tz) + ), + dependencies_database = tableRow.get("dependencies_database").elements().asScala.map(_.asText).toList, + dependencies_table = tableRow.get("dependencies_table").elements().asScala.map(_.asText).toList, + create_table_query = tableRow.get("create_table_query").asText, + engine_full = tableRow.get("engine_full").asText, + partition_key = tableRow.get("partition_key").asText, + sorting_key = tableRow.get("sorting_key").asText, + primary_key = tableRow.get("primary_key").asText, + sampling_key = tableRow.get("sampling_key").asText, + storage_policy = tableRow.get("storage_policy").asText, + total_rows = tableRow.get("total_rows") match { + case _: NullNode 
| null => None + case node: JsonNode => Some(node.asLong) + }, + total_bytes = tableRow.get("total_bytes") match { + case _: NullNode | null => None + case node: JsonNode => Some(node.asLong) + }, + lifetime_rows = tableRow.get("lifetime_rows") match { + case _: NullNode | null => None + case node: JsonNode => Some(node.asLong) + }, + lifetime_bytes = tableRow.get("lifetime_bytes") match { + case _: NullNode | null => None + case node: JsonNode => Some(node.asLong) + } + ) + } + + def queryTableSchema( + database: String, + table: String, + actionIfNoSuchTable: (String, String) => Unit = DEFAULT_ACTION_IF_NO_SUCH_TABLE + )(implicit nodeClient: NodeClient): StructType = { + val columnOutput = nodeClient.syncQueryAndCheckOutputJSONEachRow( + s"""SELECT + | `database`, -- String + | `table`, -- String + | `name`, -- String + | `type`, -- String + | `position`, -- UInt64 + | `default_kind`, -- String + | `default_expression`, -- String + | `data_compressed_bytes`, -- UInt64 + | `data_uncompressed_bytes`, -- UInt64 + | `marks_bytes`, -- UInt64 + | `comment`, -- String + | `is_in_partition_key`, -- UInt8 + | `is_in_sorting_key`, -- UInt8 + | `is_in_primary_key`, -- UInt8 + | `is_in_sampling_key`, -- UInt8 + | `compression_codec` -- String + |FROM `system`.`columns` + |WHERE `database`='$database' AND `table`='$table' + |ORDER BY `position` ASC + |""".stripMargin + ) + if (columnOutput.isEmpty) { + actionIfNoSuchTable(database, table) + } + SchemaUtils.fromClickHouseSchema(columnOutput.records.map { row => + val fieldName = row.get("name").asText + val ckType = row.get("type").asText + (fieldName, ckType) + }) + } + + def queryPartitionSpec( + database: String, + table: String + )(implicit nodeClient: NodeClient): Seq[PartitionSpec] = { + val partOutput = nodeClient.syncQueryAndCheckOutputJSONEachRow( + s"""SELECT + | partition, -- String + | partition_id, -- String + | sum(rows) AS row_count, -- UInt64 + | sum(bytes_on_disk) AS size_in_bytes -- UInt64 + |FROM `system`.`parts` + |WHERE `database`='$database' AND `table`='$table' AND `active`=1 + |GROUP BY `partition`, `partition_id` + |ORDER BY `partition` ASC, partition_id ASC + |""".stripMargin + ) + if (partOutput.isEmpty || partOutput.rows == 1 && partOutput.records.head.get("partition").asText == "tuple()") { + return Array(NoPartitionSpec) + } + partOutput.records.map { row => + PartitionSpec( + partition_value = row.get("partition").asText, + partition_id = row.get("partition_id").asText, + row_count = row.get("row_count").asLong, + size_in_bytes = row.get("size_in_bytes").asLong + ) + } + } + + /** + * This method is considered as lightweight. Typically `sql` should contains `where 1=0` to avoid running the query on + * ClickHouse. 
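+ *
+ * A minimal usage sketch (the table name is illustrative):
+ * {{{
+ *   getQueryOutputSchema("SELECT * FROM `db`.`tbl` WHERE 1=0")
+ * }}}
+ * which yields the Spark schema of the query without reading any rows.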
+ */ + def getQueryOutputSchema(sql: String)(implicit nodeClient: NodeClient): StructType = { + val namesAndTypes = nodeClient.syncQueryAndCheckOutputJSONCompactEachRowWithNamesAndTypes(sql).namesAndTypes + SchemaUtils.fromClickHouseSchema(namesAndTypes.toSeq) + } + + def dropPartition( + database: String, + table: String, + partitionExpr: String, + cluster: Option[String] = None + )(implicit + nodeClient: NodeClient + ): Boolean = + nodeClient.syncQueryOutputJSONEachRow( + s"ALTER TABLE `$database`.`$table` ${cluster.map(c => s"ON CLUSTER $c").getOrElse("")} DROP PARTITION $partitionExpr" + ) match { + case Right(_) => true + case Left(ex: CHException) => + log.error(s"[${ex.code}]: ${ex.getMessage}") + false + } + + def delete( + database: String, + table: String, + deleteExpr: String, + cluster: Option[String] = None + )(implicit + nodeClient: NodeClient + ): Boolean = + nodeClient.syncQueryOutputJSONEachRow( + s"ALTER TABLE `$database`.`$table` ${cluster.map(c => s"ON CLUSTER $c").getOrElse("")} DELETE WHERE $deleteExpr", + // https://clickhouse.com/docs/en/sql-reference/statements/alter/#synchronicity-of-alter-queries + Map("mutations_sync" -> "2") + ) match { + case Right(_) => true + case Left(ex: CHException) => + log.error(s"[${ex.code}]: ${ex.getMessage}") + false + } + + def truncateTable( + database: String, + table: String, + cluster: Option[String] = None + )(implicit + nodeClient: NodeClient + ): Boolean = nodeClient.syncQueryOutputJSONEachRow( + s"TRUNCATE TABLE `$database`.`$table` ${cluster.map(c => s"ON CLUSTER $c").getOrElse("")}" + ) match { + case Right(_) => true + case Left(ex: CHException) => + log.error(s"[${ex.code}]: ${ex.getMessage}") + false + } +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/ClickHouseTable.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/ClickHouseTable.scala new file mode 100644 index 00000000..2dec715c --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/ClickHouseTable.scala @@ -0,0 +1,310 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.clickhouse.spark + +import com.clickhouse.spark.client.NodeClient +import com.clickhouse.spark.expr.{Expr, OrderExpr} +import com.clickhouse.spark.func.FunctionRegistry +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} +import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.{READ_DISTRIBUTED_CONVERT_LOCAL, USE_NULLABLE_QUERY_SCHEMA} +import org.apache.spark.sql.clickhouse.{ExprUtils, ReadOptions, WriteOptions} +import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.catalog._ +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.LogicalWriteInfo +import org.apache.spark.sql.sources.{AlwaysTrue, Filter} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String +import Utils._ +import com.clickhouse.spark.read.{ClickHouseMetadataColumn, ClickHouseScanBuilder, ScanJobDescription} +import com.clickhouse.spark.spec.{ + ClusterSpec, + DistributedEngineSpec, + MergeTreeFamilyEngineSpec, + NodeSpec, + PartitionSpec, + TableEngineSpec, + TableEngineUtils, + TableSpec +} +import com.clickhouse.spark.write.{ClickHouseWriteBuilder, WriteJobDescription} +import com.clickhouse.spark.spec._ + +import java.lang.{Integer => JInt, Long => JLong} +import java.time.{LocalDate, ZoneId} +import java.util +import scala.collection.JavaConverters._ + +case class ClickHouseTable( + node: NodeSpec, + cluster: Option[ClusterSpec], + implicit val tz: ZoneId, + spec: TableSpec, + engineSpec: TableEngineSpec, + functionRegistry: FunctionRegistry +) extends Table + with SupportsRead + with SupportsWrite + with SupportsDelete + with TruncatableTable + with SupportsMetadataColumns + with SupportsPartitionManagement + with ClickHouseHelper + with SQLConfHelper + with SQLHelper + with Logging { + + def database: String = spec.database + + def table: String = spec.name + + def isDistributed: Boolean = engineSpec.is_distributed + + val readDistributedConvertLocal: Boolean = conf.getConf(READ_DISTRIBUTED_CONVERT_LOCAL) + + lazy val (localTableSpec, localTableEngineSpec): (Option[TableSpec], Option[MergeTreeFamilyEngineSpec]) = + engineSpec match { + case distSpec: DistributedEngineSpec => Utils.tryWithResource(NodeClient(node)) { implicit nodeClient => + val _localTableSpec = queryTableSpec(distSpec.local_db, distSpec.local_table) + val _localTableEngineSpec = + TableEngineUtils.resolveTableEngine(_localTableSpec).asInstanceOf[MergeTreeFamilyEngineSpec] + (Some(_localTableSpec), Some(_localTableEngineSpec)) + } + case _ => (None, None) + } + + def shardingKey: Option[Expr] = engineSpec match { + case _spec: DistributedEngineSpec => _spec.sharding_key + case _ => None + } + + def partitionKey: Option[List[Expr]] = engineSpec match { + case mergeTreeFamilySpec: MergeTreeFamilyEngineSpec => Some(mergeTreeFamilySpec.partition_key.exprList) + case _: DistributedEngineSpec => localTableEngineSpec.map(_.partition_key.exprList) + case _: TableEngineSpec => None + } + + def sortingKey: Option[List[OrderExpr]] = engineSpec match { + case mergeTreeFamilySpec: MergeTreeFamilyEngineSpec => Some(mergeTreeFamilySpec.order_by_expr).filter(_.nonEmpty) + case _: DistributedEngineSpec => localTableEngineSpec.map(_.order_by_expr).filter(_.nonEmpty) + case _: TableEngineSpec => None + } + + override def 
name: String = s"${wrapBackQuote(spec.database)}.${wrapBackQuote(spec.name)}" + + // for SPARK-43390 + def useNullableQuerySchema: Boolean = conf.getConf(USE_NULLABLE_QUERY_SCHEMA) + + override def capabilities(): util.Set[TableCapability] = + Set( + BATCH_READ, + BATCH_WRITE, + TRUNCATE, + ACCEPT_ANY_SCHEMA // TODO check schema and handle extra columns before writing + ).asJava + + override lazy val schema: StructType = Utils.tryWithResource(NodeClient(node)) { implicit nodeClient => + queryTableSchema(database, table) + } + + /** + * Only support `MergeTree` and `Distributed` table engine, for reference + * {{{NamesAndTypesList MergeTreeData::getVirtuals()}}} {{{NamesAndTypesList StorageDistributed::getVirtuals()}}} + */ + override lazy val metadataColumns: Array[MetadataColumn] = { + + def metadataCols(tableEngine: TableEngineSpec): Array[MetadataColumn] = tableEngine match { + case _: MergeTreeFamilyEngineSpec => ClickHouseMetadataColumn.mergeTreeMetadataCols + case _: DistributedEngineSpec => ClickHouseMetadataColumn.distributeMetadataCols + case _ => Array.empty + } + + engineSpec match { + case _: DistributedEngineSpec if readDistributedConvertLocal => metadataCols(localTableEngineSpec.get) + case other: TableEngineSpec => metadataCols(other) + } + } + + private lazy val metadataSchema: StructType = + StructType(metadataColumns.map(_.asInstanceOf[ClickHouseMetadataColumn].toStructField)) + + override lazy val partitioning: Array[Transform] = ExprUtils.toSparkPartitions(partitionKey, functionRegistry) + + override lazy val partitionSchema: StructType = StructType( + partitioning.map { partTransform => + ExprUtils.inferTransformSchema(schema, metadataSchema, partTransform, functionRegistry) + } + ) + + override lazy val properties: util.Map[String, String] = spec.toJavaMap + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + val scanJob = read.ScanJobDescription( + node = node, + tz = tz, + tableSpec = spec, + tableEngineSpec = engineSpec, + cluster = cluster, + localTableSpec = localTableSpec, + localTableEngineSpec = localTableEngineSpec, + readOptions = new ReadOptions(options.asCaseSensitiveMap()) + ) + // TODO schema of partitions + val partTransforms = Array[Transform]() + new ClickHouseScanBuilder(scanJob, schema, metadataSchema, partTransforms) + } + + override def newWriteBuilder(info: LogicalWriteInfo): ClickHouseWriteBuilder = { + val writeJob = write.WriteJobDescription( + queryId = info.queryId, + tableSchema = schema, + metadataSchema = metadataSchema, + dataSetSchema = info.schema, + node = node, + tz = tz, + tableSpec = spec, + tableEngineSpec = engineSpec, + cluster = cluster, + localTableSpec = localTableSpec, + localTableEngineSpec = localTableEngineSpec, + shardingKey = shardingKey, + partitionKey = partitionKey, + sortingKey = sortingKey, + writeOptions = new WriteOptions(info.options.asCaseSensitiveMap()), + functionRegistry = functionRegistry + ) + + new ClickHouseWriteBuilder(writeJob) + } + + override def createPartition(ident: InternalRow, props: util.Map[String, String]): Unit = + log.info("Do nothing on ClickHouse for creating partition action") + + override def dropPartition(ident: InternalRow): Boolean = { + val partitionExpr = (0 until ident.numFields).map { i => + partitionSchema.fields(i).dataType match { + case IntegerType => compileValue(ident.getInt(i)) + case LongType => compileValue(ident.getLong(i)) + case StringType => compileValue(ident.getUTF8String(i)) + case DateType => 
compileValue(LocalDate.ofEpochDay(ident.getInt(i))) + case illegal => throw new IllegalArgumentException(s"Illegal partition data type: $illegal") + } + }.mkString("(", ",", ")") + + Utils.tryWithResource(NodeClient(node)) { implicit nodeClient => + engineSpec match { + case DistributedEngineSpec(_, cluster, local_db, local_table, _, _) => + dropPartition(local_db, local_table, partitionExpr, Some(cluster)) + case _ => + dropPartition(database, table, partitionExpr) + } + } + } + + override def purgePartition(ident: InternalRow): Boolean = dropPartition(ident) + + override def truncatePartition(ident: InternalRow): Boolean = dropPartition(ident) + + override def replacePartitionMetadata(ident: InternalRow, props: util.Map[String, String]): Unit = + throw new UnsupportedOperationException("Unsupported operation: replacePartitionMetadata") + + override def loadPartitionMetadata(ident: InternalRow): util.Map[String, String] = + throw new UnsupportedOperationException("Unsupported operation: loadPartitionMetadata") + + override def listPartitionIdentifiers(names: Array[String], ident: InternalRow): Array[InternalRow] = { + assert( + names.length == ident.numFields, + s"Number of partition names (${names.length}) must be equal to " + + s"the number of partition values (${ident.numFields})." + ) + assert( + names.forall(fieldName => partitionSchema.fieldNames.contains(fieldName)), + s"Some partition names ${names.mkString("[", ", ", "]")} don't belong to " + + s"the partition schema '${partitionSchema.sql}'." + ) + + def strToSparkValue(str: String, dataType: DataType): Any = dataType match { + case StringType => UTF8String.fromString(str.stripPrefix("'").stripSuffix("'")) + case IntegerType => JInt.parseInt(str) + case LongType => JLong.parseLong(str) + case DateType => LocalDate.parse(str.stripPrefix("'").stripSuffix("'"), dateFmt).toEpochDay.toInt + case unsupported => throw new UnsupportedOperationException(s"$unsupported") + } + + val partitionSpecs: Seq[PartitionSpec] = engineSpec match { + case DistributedEngineSpec(_, _, local_db, local_table, _, _) => + cluster.get.shards.flatMap { shardSpec => + Utils.tryWithResource(NodeClient(shardSpec.nodes.head)) { implicit nodeClient: NodeClient => + queryPartitionSpec(local_db, local_table) + } + } + case _ => + Utils.tryWithResource(NodeClient(node)) { implicit nodeClient => + queryPartitionSpec(database, table) + } + } + partitionSpecs.map(_.partition_value) + .distinct + .filterNot(_.isEmpty) // represent partitioned table w/o records + .filterNot(_ == "tuple()") // represent the root partition of un-partitioned table + .map { + case tuple if tuple.startsWith("(") && tuple.endsWith(")") => + tuple.stripPrefix("(").stripSuffix(")").split(",") + case partColStrValue => + Array(partColStrValue) + } + .map { partColStrValues => + new GenericInternalRow( + (partColStrValues zip partitionSchema.fields.map(_.dataType)) + .map { case (partColStrValue, dataType) => strToSparkValue(partColStrValue, dataType) } + ) + } + .filter { partRow => + names.zipWithIndex.forall { case (name, queryIndex) => + val partRowIndex = partitionSchema.fieldIndex(name) + val dataType = partitionSchema.fields(partRowIndex).dataType + partRow.get(partRowIndex, dataType) == ident.get(queryIndex, dataType) + } + } + .toArray + } + + override def canDeleteWhere(filters: Array[Filter]): Boolean = filters.forall(f => compileFilter(f).isDefined) + + override def deleteWhere(filters: Array[Filter]): Unit = { + val deleteExpr = compileFilters(AlwaysTrue :: filters.toList) + 
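// AlwaysTrue is prepended so the compiled expression is never empty: even an empty filter list
+ // yields "(1=1)", keeping the generated ALTER TABLE ... DELETE WHERE clause syntactically valid. +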
Utils.tryWithResource(NodeClient(node)) { implicit nodeClient => + engineSpec match { + case DistributedEngineSpec(_, cluster, local_db, local_table, _, _) => + delete(local_db, local_table, deleteExpr, Some(cluster)) + case _ => + delete(database, table, deleteExpr) + } + } + } + + override def truncateTable(): Boolean = + Utils.tryWithResource(NodeClient(node)) { implicit nodeClient => + engineSpec match { + case DistributedEngineSpec(_, cluster, local_db, local_table, _, _) => + truncateTable(local_db, local_table, Some(cluster)) + case _ => + truncateTable(database, table) + } + } +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/CommitMessage.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/CommitMessage.scala new file mode 100644 index 00000000..c98f4c9d --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/CommitMessage.scala @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark + +import org.apache.spark.sql.connector.write.WriterCommitMessage + +case class CommitMessage(msg: String = "") extends WriterCommitMessage diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/Constants.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/Constants.scala new file mode 100644 index 00000000..94a14754 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/Constants.scala @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark + +import com.clickhouse.client.config.ClickHouseClientOption._ + +object Constants { + // format: off + ////////////////////////////////////////////////////////// + //////// clickhouse datasource catalog properties //////// + ////////////////////////////////////////////////////////// + final val CATALOG_PROP_HOST = "host" + final val CATALOG_PROP_TCP_PORT = "tcp_port" + final val CATALOG_PROP_HTTP_PORT = "http_port" + final val CATALOG_PROP_PROTOCOL = "protocol" + final val CATALOG_PROP_USER = "user" + final val CATALOG_PROP_PASSWORD = "password" + final val CATALOG_PROP_DATABASE = "database" + final val CATALOG_PROP_TZ = "timezone" // server(default), client, UTC+3, Asia/Shanghai, etc. + final val CATALOG_INFER_RUNTIME_ENV = "infer_runtime_env" + final val CATALOG_PROP_OPTION_PREFIX = "option." 
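+ // Catalog properties prefixed with "option." are stripped of the prefix and passed through to the
+ // underlying ClickHouse client (see ClickHouseHelper.buildNodeSpec), except for the keys listed in
+ // CATALOG_PROP_IGNORE_OPTIONS below, which are ignored with a warning.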
+ final val CATALOG_PROP_IGNORE_OPTIONS = Seq( + DATABASE.getKey, COMPRESS.getKey, DECOMPRESS.getKey, FORMAT.getKey, RETRY.getKey, + USE_SERVER_TIME_ZONE.getKey, USE_SERVER_TIME_ZONE_FOR_DATES.getKey, SERVER_TIME_ZONE.getKey, USE_TIME_ZONE.getKey, + CATALOG_INFER_RUNTIME_ENV) + + ////////////////////////////////////////////////////////// + ////////// clickhouse datasource read properties ///////// + ////////////////////////////////////////////////////////// + + ////////////////////////////////////////////////////////// + ///////// clickhouse datasource write properties ///////// + ////////////////////////////////////////////////////////// + // format: on +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/Metrics.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/Metrics.scala new file mode 100644 index 00000000..e7b84f24 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/Metrics.scala @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark + +import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric} +import Metrics._ + +case class TaskMetric(override val name: String, override val value: Long) extends CustomTaskMetric + +abstract class SizeSumMetric extends CustomMetric { + override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = Utils.bytesToString(taskMetrics.sum) +} + +abstract class DurationSumMetric extends CustomMetric { + override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = Utils.msDurationToString(taskMetrics.sum) +} + +object Metrics { + val BLOCKS_READ = "blocksRead" + val BYTES_READ = "bytesRead" + + val RECORDS_WRITTEN = "recordsWritten" + val BYTES_WRITTEN = "bytesWritten" + val SERIALIZE_TIME = "serializeTime" + val WRITE_TIME = "writeTime" +} + +case class BlocksReadMetric() extends CustomSumMetric { + override def name: String = BLOCKS_READ + override def description: String = "number of blocks" +} + +case class BytesReadMetric() extends SizeSumMetric { + override def name: String = BYTES_READ + override def description: String = "data size" +} + +case class RecordsWrittenMetric() extends CustomSumMetric { + override def name: String = RECORDS_WRITTEN + override def description: String = "number of output rows" +} + +case class BytesWrittenMetric() extends SizeSumMetric { + override def name: String = BYTES_WRITTEN + override def description: String = "written output" +} + +case class SerializeTimeMetric() extends DurationSumMetric { + override def name: String = SERIALIZE_TIME + override def description: String = "total time of serialization" +} + +case class WriteTimeMetric() extends DurationSumMetric { + override def name: String = WRITE_TIME + override def description: String = "total time of writing" +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/SQLHelper.scala 
b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/SQLHelper.scala new file mode 100644 index 00000000..6531f5fa --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/SQLHelper.scala @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark + +import java.sql.{Date, Timestamp} +import java.time.{Instant, LocalDate, LocalDateTime, ZoneId} +import org.apache.commons.lang3.StringUtils +import org.apache.spark.sql.connector.expressions.aggregate._ +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.sources._ +import org.apache.spark.unsafe.types.UTF8String +import Utils._ + +trait SQLHelper { + + def quoted(token: String) = s"`$token`" + + // null => null, ' => '' + def escapeSql(value: String): String = StringUtils.replace(value, "'", "''") + + def compileValue(value: Any)(implicit tz: ZoneId): Any = value match { + case string: String => s"'${escapeSql(string)}'" + case utf8: UTF8String => s"'${escapeSql(utf8.toString)}'" + case instant: Instant => s"'${dateTimeFmt.withZone(tz).format(instant)}'" + case timestamp: Timestamp => s"'${legacyDateTimeFmt.format(timestamp)}'" + case localDateTime: LocalDateTime => s"'${dateTimeFmt.format(localDateTime)}'" + case legacyDate: Date => s"'${legacyDateFmt.format(legacyDate)}'" + case localDate: LocalDate => s"'${dateFmt.format(localDate)}'" + case array: Array[Any] => array.map(compileValue).mkString(",") + case _ => value + } + + def compileFilter(f: Filter)(implicit tz: ZoneId): Option[String] = Option(f match { + case AlwaysTrue => "1=1" + case AlwaysFalse => "1=0" + case EqualTo(attr, value) => s"${quoted(attr)} = ${compileValue(value)}" + case EqualNullSafe(attr, nullableValue) => + val (col, value) = (quoted(attr), compileValue(nullableValue)) + s"(NOT ($col != $value OR $col IS NULL OR $value IS NULL) OR ($col IS NULL AND $value IS NULL))" + case LessThan(attr, value) => s"${quoted(attr)} < ${compileValue(value)}" + case GreaterThan(attr, value) => s"${quoted(attr)} > ${compileValue(value)}" + case LessThanOrEqual(attr, value) => s"${quoted(attr)} <= ${compileValue(value)}" + case GreaterThanOrEqual(attr, value) => s"${quoted(attr)} >= ${compileValue(value)}" + case IsNull(attr) => s"${quoted(attr)} IS NULL" + case IsNotNull(attr) => s"${quoted(attr)} IS NOT NULL" + case StringStartsWith(attr, value) => s"${quoted(attr)} LIKE '$value%'" + case StringEndsWith(attr, value) => s"${quoted(attr)} LIKE '%$value'" + case StringContains(attr, value) => s"${quoted(attr)} LIKE '%$value%'" + case In(attr, value) if value.isEmpty => s"CASE WHEN ${quoted(attr)} IS NULL THEN NULL ELSE FALSE END" + case In(attr, value) => s"${quoted(attr)} IN (${compileValue(value)})" + case Not(f) => compileFilter(f).map(p => s"(NOT ($p))").orNull + case Or(f1, f2) => + val or = Seq(f1, f2).flatMap(_f => compileFilter(_f)(tz)) + if (or.size == 2) or.map(p => s"($p)").mkString(" OR ") else null + case And(f1, f2) => + val and = Seq(f1, 
f2).flatMap(_f => compileFilter(_f)(tz)) + if (and.size == 2) and.map(p => s"($p)").mkString(" AND ") else null + case _ => null + }) + + def compileAggregate(aggFunction: AggregateFunc): Option[String] = + aggFunction match { + case min: Min if min.column.isInstanceOf[NamedReference] => + val col = min.column.asInstanceOf[NamedReference] + if (col.fieldNames().length != 1) return None + Some(s"MIN(${quoted(col.fieldNames.head)})") + case max: Max if max.column.isInstanceOf[NamedReference] => + val col = max.column.asInstanceOf[NamedReference] + if (col.fieldNames.length != 1) return None + Some(s"MAX(${quoted(col.fieldNames.head)})") + case count: Count if count.column.isInstanceOf[NamedReference] => + val col = count.column.asInstanceOf[NamedReference] + if (col.fieldNames.length != 1) return None + val distinct = if (count.isDistinct) "DISTINCT " else "" + val column = quoted(col.fieldNames.head) + Some(s"COUNT($distinct$column)") + case sum: Sum if sum.column.isInstanceOf[NamedReference] => + val col = sum.column.asInstanceOf[NamedReference] + if (col.fieldNames.length != 1) return None + val distinct = if (sum.isDistinct) "DISTINCT " else "" + val column = quoted(col.fieldNames.head) + Some(s"SUM($distinct$column)") + case _: CountStar => + Some("COUNT(*)") + case _ => None + } + + def compileFilters(filters: Seq[Filter])(implicit tz: ZoneId): String = + filters + .flatMap(_f => compileFilter(_f)(tz)) + .map(p => s"($p)").mkString(" AND ") +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/CityHash64.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/CityHash64.scala new file mode 100644 index 00000000..5cd677d1 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/CityHash64.scala @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark.func + +import com.clickhouse.spark.hash + +// https://github.com/ClickHouse/ClickHouse/blob/v23.5.3.24-stable/src/Functions/FunctionsHashing.h#L694 +object CityHash64 extends MultiStringArgsHash { + + override protected def funcName: String = "clickhouse_cityHash64" + + override val ckFuncNames: Array[String] = Array("cityHash64") + + override def applyHash(input: Array[Any]): Long = hash.CityHash64(input) +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/FunctionRegistry.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/FunctionRegistry.scala new file mode 100644 index 00000000..25263717 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/FunctionRegistry.scala @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark.func + +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction + +import scala.collection.mutable + +trait FunctionRegistry extends Serializable { + + def list: Array[String] + + def load(name: String): Option[UnboundFunction] + + def sparkToClickHouseFunc: Map[String, String] + + def clickHouseToSparkFunc: Map[String, String] +} + +trait ClickhouseEquivFunction { + val ckFuncNames: Array[String] +} + +class CompositeFunctionRegistry(registries: Array[FunctionRegistry]) extends FunctionRegistry { + + override def list: Array[String] = registries.flatMap(_.list) + + override def load(name: String): Option[UnboundFunction] = registries.flatMap(_.load(name)).headOption + + override def sparkToClickHouseFunc: Map[String, String] = registries.flatMap(_.sparkToClickHouseFunc).toMap + + override def clickHouseToSparkFunc: Map[String, String] = registries.flatMap(_.clickHouseToSparkFunc).toMap +} + +object StaticFunctionRegistry extends FunctionRegistry { + + private val functions = Map[String, UnboundFunction]( + "ck_xx_hash64" -> ClickHouseXxHash64, // for compatible + "clickhouse_xxHash64" -> ClickHouseXxHash64, + "clickhouse_murmurHash2_32" -> MurmurHash2_32, + "clickhouse_murmurHash2_64" -> MurmurHash2_64, + "clickhouse_murmurHash3_32" -> MurmurHash3_32, + "clickhouse_murmurHash3_64" -> MurmurHash3_64, + "clickhouse_cityHash64" -> CityHash64 + ) + + override def list: Array[String] = functions.keys.toArray + + override def load(name: String): Option[UnboundFunction] = functions.get(name) + + override val sparkToClickHouseFunc: Map[String, String] = + functions.filter(_._2.isInstanceOf[ClickhouseEquivFunction]).flatMap { case (k, v) => + v.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.map((k, _)) + } + + override val clickHouseToSparkFunc: Map[String, String] = + functions.filter(_._2.isInstanceOf[ClickhouseEquivFunction]).flatMap { case (k, v) => + v.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.map((_, k)) + } +} + +class DynamicFunctionRegistry extends FunctionRegistry { + + private val functions = mutable.Map[String, UnboundFunction]() + + def register(name: String, function: UnboundFunction): DynamicFunctionRegistry = { + functions += (name -> function) + this + } + + override def list: Array[String] = functions.keys.toArray + + override def load(name: String): Option[UnboundFunction] = functions.get(name) + + override def sparkToClickHouseFunc: Map[String, String] = + functions.filter(_._2.isInstanceOf[ClickhouseEquivFunction]).toMap.flatMap { case (k, v) => + v.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.map((k, _)) + } + + override def clickHouseToSparkFunc: Map[String, String] = + functions.filter(_._2.isInstanceOf[ClickhouseEquivFunction]).toMap.flatMap { case (k, v) => + v.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.map((_, k)) + } +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/MultiStringArgsHash.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/MultiStringArgsHash.scala new file mode 100644 index 00000000..d68fb2ea --- /dev/null 
+++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/MultiStringArgsHash.scala @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark.func + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +abstract class MultiStringArgsHash extends UnboundFunction with ClickhouseEquivFunction { + + def applyHash(input: Array[Any]): Long + + protected def funcName: String + + override val ckFuncNames: Array[String] + + override def description: String = s"$name: (value: string, ...) => hash_value: long" + + private def isExceptedType(dt: DataType): Boolean = + dt.isInstanceOf[StringType] + + final override def name: String = funcName + + final override def bind(inputType: StructType): BoundFunction = { + val inputDataTypes = inputType.fields.map(_.dataType) + if (inputDataTypes.forall(isExceptedType)) { + // need to new a ScalarFunction instance for each bind, + // because we do not know the number of arguments in advance + new ScalarFunction[Long] { + override def inputTypes(): Array[DataType] = inputDataTypes + override def name: String = funcName + override def canonicalName: String = s"clickhouse.$name" + override def resultType: DataType = LongType + override def toString: String = name + override def produceResult(input: InternalRow): Long = { + val inputStrings = new Array[Any](input.numFields) + var i = 0 + do { + inputStrings(i) = input.getUTF8String(i).getBytes + i += 1 + } while (i < input.numFields) + applyHash(inputStrings) + } + } + } else { + throw new UnsupportedOperationException(s"Expect multiple STRING argument. $description") + } + } + +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/MurmurHash2.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/MurmurHash2.scala new file mode 100644 index 00000000..e6791e4e --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/MurmurHash2.scala @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.clickhouse.spark.func + +import com.clickhouse.spark.hash.{HashUtils, Murmurhash2_32, Murmurhash2_64} +import com.clickhouse.spark.hash + +// https://github.com/ClickHouse/ClickHouse/blob/v23.5.3.24-stable/src/Functions/FunctionsHashing.h#L460 +object MurmurHash2_64 extends MultiStringArgsHash { + + override protected def funcName: String = "clickhouse_murmurHash2_64" + + override val ckFuncNames: Array[String] = Array("murmurHash2_64") + + override def applyHash(input: Array[Any]): Long = Murmurhash2_64(input) +} + +// https://github.com/ClickHouse/ClickHouse/blob/v23.5.3.24-stable/src/Functions/FunctionsHashing.h#L519 +object MurmurHash2_32 extends MultiStringArgsHash { + + override protected def funcName: String = "clickhouse_murmurHash2_32" + + override val ckFuncNames: Array[String] = Array("murmurHash2_32") + + override def applyHash(input: Array[Any]): Long = HashUtils.toUInt32(Murmurhash2_32(input)) +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/MurmurHash3.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/MurmurHash3.scala new file mode 100644 index 00000000..a9dc2ba9 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/MurmurHash3.scala @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark.func + +import com.clickhouse.spark.hash.{HashUtils, Murmurhash3_32, Murmurhash3_64} +import com.clickhouse.spark.hash + +// https://github.com/ClickHouse/ClickHouse/blob/v23.5.3.24-stable/src/Functions/FunctionsHashing.h#L543 +object MurmurHash3_64 extends MultiStringArgsHash { + + override protected def funcName: String = "clickhouse_murmurHash3_64" + + override val ckFuncNames: Array[String] = Array("murmurHash3_64") + + override def applyHash(input: Array[Any]): Long = Murmurhash3_64(input) +} + +// https://github.com/ClickHouse/ClickHouse/blob/v23.5.3.24-stable/src/Functions/FunctionsHashing.h#L519 +object MurmurHash3_32 extends MultiStringArgsHash { + + override protected def funcName: String = "clickhouse_murmurHash3_32" + + override val ckFuncNames: Array[String] = Array("murmurHash3_32") + + override def applyHash(input: Array[Any]): Long = HashUtils.toUInt32(Murmurhash3_32(input)) +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/XxHash64.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/XxHash64.scala new file mode 100644 index 00000000..7e2b5287 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/func/XxHash64.scala @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark.func + +import com.clickhouse.spark.spec.{ClusterSpec, ShardUtils} +import org.apache.spark.sql.catalyst.expressions.XxHash64Function +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * ClickHouse equivalent function: + * {{{ + * select xxHash64(concat(project_id, toString(seq)) + * }}} + */ +object ClickHouseXxHash64 extends UnboundFunction with ScalarFunction[Long] with ClickhouseEquivFunction { + + override def name: String = "clickhouse_xxHash64" + + override def canonicalName: String = s"clickhouse.$name" + + override def toString: String = name + + override val ckFuncNames: Array[String] = Array("xxHash64") + + override def description: String = s"$name: (value: string) => hash_value: long" + + override def bind(inputType: StructType): BoundFunction = inputType.fields match { + case Array(StructField(_, StringType, _, _)) => this + case _ => throw new UnsupportedOperationException(s"Expect 1 STRING argument. $description") + } + + override def inputTypes: Array[DataType] = Array(StringType) + + override def resultType: DataType = LongType + + override def isResultNullable: Boolean = false + + // ignore UInt64 vs Int64 + def invoke(value: UTF8String): Long = XxHash64Function.hash(value, StringType, 0L) +} + +/** + * Create ClickHouse table with DDL: + * {{{ + * CREATE TABLE ON CLUSTER cluster ( + * ... + * ) ENGINE = Distributed( + * cluster, + * db, + * local_table, + * xxHash64(concat(project_id, project_version, toString(seq)) + * ); + * }}} + */ +class ClickHouseXxHash64Shard(clusters: Seq[ClusterSpec]) extends UnboundFunction with ScalarFunction[Int] { + + @transient private lazy val indexedClusters = + clusters.map(cluster => UTF8String.fromString(cluster.name) -> cluster).toMap + + override def name: String = "clickhouse_shard_xxHash64" + + override def canonicalName: String = s"clickhouse.$name" + + override def description: String = s"$name: (cluster_name: string, value: string) => shard_num: int" + + override def bind(inputType: StructType): BoundFunction = inputType.fields match { + case Array(StructField(_, StringType, _, _), StructField(_, StringType, _, _)) => this + case _ => throw new UnsupportedOperationException(s"Expect 2 STRING argument. 
$description") + } + + override def inputTypes: Array[DataType] = Array(StringType, StringType) + + override def resultType: DataType = IntegerType + + override def isResultNullable: Boolean = false + + def invoke(clusterName: UTF8String, value: UTF8String): Int = { + val clusterSpec = + indexedClusters.getOrElse(clusterName, throw new RuntimeException(s"Unknown cluster: $clusterName")) + val hashVal = XxHash64Function.hash(value, StringType, 0L) + ShardUtils.calcShard(clusterSpec, hashVal).num + } +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/ClickHouseMetadataColumn.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/ClickHouseMetadataColumn.scala new file mode 100644 index 00000000..9bc518e1 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/ClickHouseMetadataColumn.scala @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark.read + +import org.apache.spark.sql.connector.catalog.MetadataColumn +import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType, StructField} + +object ClickHouseMetadataColumn { + val mergeTreeMetadataCols: Array[MetadataColumn] = Array( + ClickHouseMetadataColumn("_part", StringType), + ClickHouseMetadataColumn("_part_index", LongType), + ClickHouseMetadataColumn("_part_uuid", StringType), + ClickHouseMetadataColumn("_partition_id", StringType), + // ClickHouseMetadataColumn("_partition_value", StringType), + ClickHouseMetadataColumn("_sample_factor", DoubleType) + ) + + val distributeMetadataCols: Array[MetadataColumn] = Array( + ClickHouseMetadataColumn("_table", StringType), + ClickHouseMetadataColumn("_part", StringType), + ClickHouseMetadataColumn("_part_index", LongType), + ClickHouseMetadataColumn("_part_uuid", StringType), + ClickHouseMetadataColumn("_partition_id", StringType), + ClickHouseMetadataColumn("_sample_factor", DoubleType), + ClickHouseMetadataColumn("_shard_num", IntegerType) + ) +} + +case class ClickHouseMetadataColumn( + override val name: String, + override val dataType: DataType, + override val isNullable: Boolean = false +) extends MetadataColumn { + def toStructField: StructField = StructField(name, dataType, isNullable) +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/ClickHouseRead.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/ClickHouseRead.scala new file mode 100644 index 00000000..d210fb0f --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/ClickHouseRead.scala @@ -0,0 +1,223 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark.read + +import com.clickhouse.spark.client.NodeClient +import com.clickhouse.spark.exception.CHClientException +import com.clickhouse.spark.read.format.{ClickHouseBinaryReader, ClickHouseJsonReader} +import com.clickhouse.spark.spec.{DistributedEngineSpec, NoPartitionSpec, TableEngineSpec} +import com.clickhouse.spark.{BlocksReadMetric, BytesReadMetric, ClickHouseHelper, Logging, SQLHelper, Utils} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.clickhouse.ClickHouseSQLConf._ +import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference, Transform} +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.metric.CustomMetric +import org.apache.spark.sql.connector.read._ +import org.apache.spark.sql.connector.read.partitioning.{Partitioning, UnknownPartitioning} +import org.apache.spark.sql.sources.{AlwaysTrue, Filter} +import org.apache.spark.sql.types.StructType +import com.clickhouse.spark._ +import com.clickhouse.spark.spec._ + +import java.time.ZoneId +import scala.util.control.NonFatal + +class ClickHouseScanBuilder( + scanJob: ScanJobDescription, + physicalSchema: StructType, + metadataSchema: StructType, + partitionTransforms: Array[Transform] +) extends ScanBuilder + with SupportsPushDownLimit + with SupportsPushDownFilters + with SupportsPushDownAggregates + with SupportsPushDownRequiredColumns + with ClickHouseHelper + with SQLHelper + with Logging { + + implicit private val tz: ZoneId = scanJob.tz + + private val reservedMetadataSchema: StructType = StructType( + metadataSchema.dropWhile(field => physicalSchema.fields.map(_.name).contains(field.name)) + ) + + private var _readSchema: StructType = StructType( + physicalSchema.fields ++ reservedMetadataSchema.fields + ) + + private var _limit: Option[Int] = None + + override def pushLimit(limit: Int): Boolean = { + this._limit = Some(limit) + true + } + + private var _pushedFilters = Array.empty[Filter] + + override def pushedFilters: Array[Filter] = this._pushedFilters + + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val (pushed, unSupported) = filters.partition(f => compileFilter(f).isDefined) + this._pushedFilters = pushed + unSupported + } + + private var _pushedGroupByCols: Option[Array[String]] = None + private var _groupByClause: Option[String] = None + + override def pushAggregation(aggregation: Aggregation): Boolean = { + val compiledAggs = aggregation.aggregateExpressions.flatMap(compileAggregate) + if (compiledAggs.length != aggregation.aggregateExpressions.length) return false + + val compiledGroupByCols = aggregation.groupByExpressions.map(_.toString) + + // The column names here are already quoted and can be used to build sql string directly. + // e.g. 
[`DEPT`, `NAME`, MAX(`SALARY`), MIN(`BONUS`)] => + // SELECT `DEPT`, `NAME`, MAX(`SALARY`), MIN(`BONUS`) + // FROM `test`.`employee` + // WHERE 1=0 + // GROUP BY `DEPT`, `NAME` + val compiledSelectItems = compiledGroupByCols ++ compiledAggs + val groupByClause = if (compiledGroupByCols.nonEmpty) "GROUP BY " + compiledGroupByCols.mkString(", ") else "" + val aggQuery = + s"""SELECT ${compiledSelectItems.mkString(", ")} + |FROM ${quoted(scanJob.tableSpec.database)}.${quoted(scanJob.tableSpec.name)} + |WHERE 1=0 + |$groupByClause + |""".stripMargin + try { + _readSchema = Utils.tryWithResource(NodeClient(scanJob.node)) { implicit nodeClient: NodeClient => + val fields = (getQueryOutputSchema(aggQuery) zip compiledSelectItems) + .map { case (structField, colExpr) => structField.copy(name = colExpr) } + StructType(fields) + } + _pushedGroupByCols = Some(compiledGroupByCols) + _groupByClause = Some(groupByClause) + true + } catch { + case NonFatal(e) => + log.error("Failed to push down aggregation to ClickHouse", e) + false + } + } + + override def pruneColumns(requiredSchema: StructType): Unit = { + val requiredCols = requiredSchema.map(_.name) + this._readSchema = StructType(_readSchema.filter(field => requiredCols.contains(field.name))) + } + + override def build(): Scan = new ClickHouseBatchScan(scanJob.copy( + readSchema = _readSchema, + filtersExpr = compileFilters(AlwaysTrue :: pushedFilters.toList), + groupByClause = _groupByClause, + limit = _limit + )) +} + +class ClickHouseBatchScan(scanJob: ScanJobDescription) extends Scan with Batch + with SupportsReportPartitioning + with SupportsRuntimeFiltering + with PartitionReaderFactory + with ClickHouseHelper + with SQLHelper { + + implicit private val tz: ZoneId = scanJob.tz + + private var runtimeFilters: Array[Filter] = Array.empty + + val database: String = scanJob.database + val table: String = scanJob.table + + lazy val inputPartitions: Array[ClickHouseInputPartition] = scanJob.tableEngineSpec match { + case DistributedEngineSpec(_, _, local_db, local_table, _, _) if scanJob.readOptions.convertDistributedToLocal => + scanJob.cluster.get.shards.flatMap { shardSpec => + Utils.tryWithResource(NodeClient(shardSpec.nodes.head)) { implicit nodeClient: NodeClient => + queryPartitionSpec(local_db, local_table).map { partitionSpec => + ClickHouseInputPartition( + scanJob.localTableSpec.get, + partitionSpec, + scanJob.readOptions.splitByPartitionId, + shardSpec // TODO pickup preferred + ) + } + } + } + case _: DistributedEngineSpec if scanJob.readOptions.useClusterNodesForDistributed => + throw CHClientException( + s"${READ_DISTRIBUTED_USE_CLUSTER_NODES.key} is not supported yet." 
+ ) + case _: DistributedEngineSpec => + // we can not collect all partitions from single node, thus should treat table as no partitioned table + Array(ClickHouseInputPartition( + scanJob.tableSpec, + NoPartitionSpec, + scanJob.readOptions.splitByPartitionId, + scanJob.node + )) + case _: TableEngineSpec => + Utils.tryWithResource(NodeClient(scanJob.node)) { implicit nodeClient: NodeClient => + queryPartitionSpec(database, table).map { partitionSpec => + ClickHouseInputPartition( + scanJob.tableSpec, + partitionSpec, + scanJob.readOptions.splitByPartitionId, + scanJob.node // TODO pickup preferred + ) + } + }.toArray + } + + override def toBatch: Batch = this + + // may contains meta columns + override def readSchema(): StructType = scanJob.readSchema + + override def planInputPartitions: Array[InputPartition] = inputPartitions.toArray + + // TODO KeyGroupedPartitioning + override def outputPartitioning(): Partitioning = new UnknownPartitioning(inputPartitions.length) + + override def createReaderFactory: PartitionReaderFactory = this + + override def createReader(_partition: InputPartition): PartitionReader[InternalRow] = { + val format = scanJob.readOptions.format + val partition = _partition.asInstanceOf[ClickHouseInputPartition] + val finalScanJob = scanJob.copy(filtersExpr = + scanJob.filtersExpr + " AND " + + compileFilters(AlwaysTrue :: runtimeFilters.toList) + ) + format match { + case "json" => new ClickHouseJsonReader(finalScanJob, partition) + case "binary" => new ClickHouseBinaryReader(finalScanJob, partition) + case unsupported => throw CHClientException(s"Unsupported read format: $unsupported") + } + } + + override def supportedCustomMetrics(): Array[CustomMetric] = Array( + BlocksReadMetric(), + BytesReadMetric() + ) + + override def filterAttributes(): Array[NamedReference] = + if (scanJob.readOptions.runtimeFilterEnabled) { + scanJob.readSchema.fields.map(field => Expressions.column(field.name)) + } else { + Array.empty + } + + override def filter(filters: Array[Filter]): Unit = + runtimeFilters = filters +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/ClickHouseReader.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/ClickHouseReader.scala new file mode 100644 index 00000000..53246f1b --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/ClickHouseReader.scala @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
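ClickHouseBatchScan above decides how a query is split into input partitions and which wire format the per-partition reader uses. The knobs it consults are the SQL confs defined later in this patch (ClickHouseSQLConf); a hedged sketch of setting them from a Spark session, with a placeholder catalog, table, and column name:

import org.apache.spark.sql.SparkSession

object ReadOptionsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    // Wire format consumed by createReader: "json" or "binary".
    spark.conf.set("spark.clickhouse.read.format", "binary")
    // Opt in to runtime filters delivered via filterAttributes()/filter().
    spark.conf.set("spark.clickhouse.read.runtimeFilter.enabled", "true")
    // Build one input partition per ClickHouse _partition_id where possible.
    spark.conf.set("spark.clickhouse.read.splitByPartitionId", "true")

    spark.table("clickhouse.db.events").groupBy("user_id").count().show()

    spark.stop()
  }
}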
+ */ + +package com.clickhouse.spark.read + +import com.clickhouse.spark.{ClickHouseHelper, Logging, TaskMetric} +import com.clickhouse.spark.client.{NodeClient, NodesClient} +import com.clickhouse.data.ClickHouseCompression +import com.clickhouse.spark.format.StreamOutput +import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} +import org.apache.spark.sql.clickhouse.ClickHouseSQLConf._ +import org.apache.spark.sql.connector.metric.CustomTaskMetric +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.types._ +import com.clickhouse.spark.Metrics.{BLOCKS_READ, BYTES_READ} +import com.clickhouse.client.ClickHouseResponse +import com.clickhouse.client.api.query.QueryResponse + +abstract class ClickHouseReader[Record]( + scanJob: ScanJobDescription, + part: ClickHouseInputPartition +) extends PartitionReader[InternalRow] + with ClickHouseHelper + with SQLConfHelper + with Logging { + + val readDistributedUseClusterNodes: Boolean = conf.getConf(READ_DISTRIBUTED_USE_CLUSTER_NODES) + val readDistributedConvertLocal: Boolean = conf.getConf(READ_DISTRIBUTED_CONVERT_LOCAL) + private val readSettings: Option[String] = conf.getConf(READ_SETTINGS) + + val database: String = part.table.database + val table: String = part.table.name + val readSchema: StructType = scanJob.readSchema + + private lazy val nodesClient = NodesClient(part.candidateNodes) + + def nodeClient: NodeClient = nodesClient.node + + lazy val scanQuery: String = { + val selectItems = + if (readSchema.isEmpty) { + "1" // for case like COUNT(*) which prunes all columns + } else { + readSchema.map { + field => if (scanJob.groupByClause.isDefined) field.name else s"`${field.name}`" + }.mkString(", ") + } + s"""SELECT $selectItems + |FROM `$database`.`$table` + |WHERE (${part.partFilterExpr}) AND (${scanJob.filtersExpr}) + |${scanJob.groupByClause.getOrElse("")} + |${scanJob.limit.map(n => s"LIMIT $n").getOrElse("")} + |${readSettings.map(settings => s"SETTINGS $settings").getOrElse("")} + |""".stripMargin + } + + def format: String + + lazy val resp: QueryResponse = nodeClient.queryAndCheck(scanQuery, format) + + def totalBlocksRead: Long = 0L // resp.getSummary.getStatistics.getBlocks + + def totalBytesRead: Long = resp.getReadBytes // resp.getSummary.getReadBytes + + override def currentMetricsValues: Array[CustomTaskMetric] = Array( + TaskMetric(BLOCKS_READ, totalBlocksRead), + TaskMetric(BYTES_READ, totalBytesRead) + ) + + def streamOutput: Iterator[Record] + + private var currentRecord: Record = _ + + override def next(): Boolean = { + val hasNext = streamOutput.hasNext + if (hasNext) currentRecord = streamOutput.next + hasNext + } + + override def get: InternalRow = decode(currentRecord) + + def decode(record: Record): InternalRow + + override def close(): Unit = nodesClient.close() +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/InputPartitions.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/InputPartitions.scala new file mode 100644 index 00000000..13ed9744 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/InputPartitions.scala @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
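ClickHouseReader assembles its scan SQL from the pruned read schema, the partition filter, the pushed-down filters, and optional GROUP BY / LIMIT / SETTINGS clauses. A standalone, simplified re-statement of that assembly (names and example arguments are illustrative only):

object ScanQuerySketch {
  // Mirrors the shape of ClickHouseReader.scanQuery without any connector dependencies.
  def buildScanQuery(
      database: String,
      table: String,
      columns: Seq[String],
      partFilter: String,
      pushedFilter: String,
      groupBy: Option[String],
      limit: Option[Int],
      settings: Option[String]
  ): String = {
    // A COUNT(*)-style scan prunes every column, so fall back to a constant select item.
    val selectItems = if (columns.isEmpty) "1" else columns.map(c => s"`$c`").mkString(", ")
    s"""SELECT $selectItems
       |FROM `$database`.`$table`
       |WHERE ($partFilter) AND ($pushedFilter)
       |${groupBy.getOrElse("")}
       |${limit.map(n => s"LIMIT $n").getOrElse("")}
       |${settings.map(s => s"SETTINGS $s").getOrElse("")}
       |""".stripMargin
  }

  def main(args: Array[String]): Unit =
    println(buildScanQuery("db", "events", Seq("id", "ts"),
      "_partition_id = '202401'", "`id` > 100", None, Some(10), Some("max_execution_time=5")))
}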
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark.read + +import com.clickhouse.spark.spec.{NoPartitionSpec, NodeSpec, Nodes, PartitionSpec, TableSpec} +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.partitioning.Partitioning +import com.clickhouse.spark.spec._ + +case class ClickHousePartitioning(inputParts: Array[ClickHouseInputPartition]) extends Partitioning { + + override def numPartitions(): Int = inputParts.length + +} + +case class ClickHouseInputPartition( + table: TableSpec, + partition: PartitionSpec, + filterByPartitionId: Boolean, + candidateNodes: Nodes, // try to use them only when preferredNode unavailable + preferredNode: Option[NodeSpec] = None // TODO assigned by ScanBuilder in Spark Driver side +) extends InputPartition { + + override def preferredLocations(): Array[String] = preferredNode match { + case Some(preferred) => Array(preferred.host) + case None => candidateNodes.nodes.map(_.host) + } + + def partFilterExpr: String = partition match { + case NoPartitionSpec => "1=1" + case PartitionSpec(_, partitionId, _, _) if filterByPartitionId => + s"_partition_id = '$partitionId'" + case PartitionSpec(partitionValue, _, _, _) => + s"${table.partition_key} = ${compilePartitionFilterValue(partitionValue)}" + } + + // TODO improve and test + def compilePartitionFilterValue(partitionValue: String): String = + (partitionValue.contains("-"), partitionValue.contains("(")) match { + // quote when partition by a single Date Type column to avoid illegal types of arguments (Date, Int64) + case (true, false) => s"'$partitionValue'" + // Date type column is quoted if there are multi partition columns + case _ => s"$partitionValue" + } +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/ScanJobDescription.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/ScanJobDescription.scala new file mode 100644 index 00000000..0b4c8bcb --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/ScanJobDescription.scala @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
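ClickHouseInputPartition turns each ClickHouse partition into a WHERE fragment, either on the `_partition_id` virtual column or on the partition key value, with a heuristic for quoting single Date-typed values. A standalone sketch of that logic (simplified; the real code pattern-matches on PartitionSpec):

object PartitionFilterSketch {
  def partFilterExpr(
      partitionId: String,
      partitionValue: String,
      partitionKey: String,
      filterByPartitionId: Boolean
  ): String =
    if (filterByPartitionId) s"_partition_id = '$partitionId'"
    else {
      // Quote a single Date-typed partition value (e.g. 2024-01-01); leave tuples like (1, 'x') untouched.
      val compiled = (partitionValue.contains("-"), partitionValue.contains("(")) match {
        case (true, false) => s"'$partitionValue'"
        case _             => partitionValue
      }
      s"$partitionKey = $compiled"
    }

  def main(args: Array[String]): Unit = {
    println(partFilterExpr("202401", "2024-01-01", "toDate(ts)", filterByPartitionId = false))
    println(partFilterExpr("202401", "2024-01-01", "toDate(ts)", filterByPartitionId = true))
  }
}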
+ */
+
+package com.clickhouse.spark.read
+
+import com.clickhouse.spark.spec.{ClusterSpec, DistributedEngineSpec, NodeSpec, TableEngineSpec, TableSpec}
+import org.apache.spark.sql.clickhouse.ReadOptions
+import org.apache.spark.sql.types.StructType
+
+import java.time.ZoneId
+
+case class ScanJobDescription(
+  node: NodeSpec,
+  tz: ZoneId,
+  tableSpec: TableSpec,
+  tableEngineSpec: TableEngineSpec,
+  cluster: Option[ClusterSpec],
+  localTableSpec: Option[TableSpec],
+  localTableEngineSpec: Option[TableEngineSpec],
+  readOptions: ReadOptions,
+  // Below fields will be constructed in ScanBuilder.
+  readSchema: StructType = new StructType,
+  // We should pass compiled ClickHouse SQL snippets (or a ClickHouse SQL AST data structure) instead of Spark
+  // Expressions into Scan tasks because the check happens in the planning phase on the driver side.
+  filtersExpr: String = "1=1",
+  groupByClause: Option[String] = None,
+  limit: Option[Int] = None
+) {
+
+  def database: String = tableEngineSpec match {
+    case dist: DistributedEngineSpec if readOptions.convertDistributedToLocal => dist.local_db
+    case _ => tableSpec.database
+  }
+
+  def table: String = tableEngineSpec match {
+    case dist: DistributedEngineSpec if readOptions.convertDistributedToLocal => dist.local_table
+    case _ => tableSpec.name
+  }
+}
diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/format/ClickHouseBinaryReader.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/format/ClickHouseBinaryReader.scala
new file mode 100644
index 00000000..dd1f9127
--- /dev/null
+++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/format/ClickHouseBinaryReader.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
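ScanJobDescription resolves which physical database/table the readers actually hit: when reading a Distributed table with spark.clickhouse.read.distributed.convertLocal enabled, the shard-local table is targeted instead of the Distributed engine itself. A minimal sketch of that resolution, using an assumed simplified spec type:

final case class DistributedSpec(localDb: String, localTable: String)

object TargetResolutionSketch {
  def target(
      db: String,
      table: String,
      dist: Option[DistributedSpec],
      convertToLocal: Boolean
  ): (String, String) =
    dist match {
      case Some(d) if convertToLocal => (d.localDb, d.localTable) // read the shard-local tables directly
      case _                         => (db, table)               // read through the Distributed engine
    }

  def main(args: Array[String]): Unit = {
    println(target("db", "events_dist", Some(DistributedSpec("db", "events_local")), convertToLocal = true))
    println(target("db", "events_dist", Some(DistributedSpec("db", "events_local")), convertToLocal = false))
  }
}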
+ */ + +package com.clickhouse.spark.read.format + +import com.clickhouse.client.api.data_formats.internal.BinaryStreamReader +import com.clickhouse.client.api.data_formats.{ClickHouseBinaryFormatReader, RowBinaryWithNamesAndTypesFormatReader} +import com.clickhouse.client.api.query.{GenericRecord, Records} + +import java.util.Collections +import com.clickhouse.data.value.{ + ClickHouseArrayValue, + ClickHouseBoolValue, + ClickHouseDoubleValue, + ClickHouseFloatValue, + ClickHouseIntegerValue, + ClickHouseLongValue, + ClickHouseMapValue, + ClickHouseStringValue +} +import com.clickhouse.data.{ClickHouseArraySequence, ClickHouseRecord, ClickHouseValue} +import com.clickhouse.spark.exception.CHClientException +import com.clickhouse.spark.read.{ClickHouseInputPartition, ClickHouseReader, ScanJobDescription} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +import java.io.InputStream +import java.time.{LocalDate, ZoneOffset, ZonedDateTime} +import java.util +import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ + +class ClickHouseBinaryReader( + scanJob: ScanJobDescription, + part: ClickHouseInputPartition +) extends ClickHouseReader[GenericRecord](scanJob, part) { + + override val format: String = "RowBinaryWithNamesAndTypes" + + lazy val streamOutput: Iterator[GenericRecord] = { + val inputString: InputStream = resp.getInputStream + val cbfr: ClickHouseBinaryFormatReader = new RowBinaryWithNamesAndTypesFormatReader( + inputString, + resp.getSettings, + new BinaryStreamReader.DefaultByteBufferAllocator + ) + val r = new Records(resp, cbfr) + r.asScala.iterator + } + + override def decode(record: GenericRecord): InternalRow = { + val size = record.getSchema.getColumns.size() + val values: Array[Any] = new Array[Any](size) + if (readSchema.nonEmpty) { + var i: Int = 0 + while (i < size) { + val v: Object = record.getObject(i + 1) + values(i) = decodeValue(v, readSchema.fields(i)) + i = i + 1 + } + } + new GenericInternalRow(values) + } + + private def decodeValue(value: Object, structField: StructField): Any = { + if (value == null) { + // should we check `structField.nullable`? 
+ return null + } + + structField.dataType match { + case BooleanType => value.asInstanceOf[Boolean] + case ByteType => value.asInstanceOf[Byte] + case ShortType => value.asInstanceOf[Short] +// case IntegerType if value.getClass.toString.equals("class java.lang.Long") => + case IntegerType if value.isInstanceOf[java.lang.Long] => + val v: Integer = Integer.valueOf(value.asInstanceOf[Long].toInt) + v.intValue() + case IntegerType => + value.asInstanceOf[Integer].intValue() + case LongType if value.isInstanceOf[java.math.BigInteger] => + value.asInstanceOf[java.math.BigInteger].longValue() + case LongType => + value.asInstanceOf[Long] + case FloatType => value.asInstanceOf[Float] + case DoubleType => value.asInstanceOf[Double] + case d: DecimalType => + // Java client returns BigInteger for Int256/UInt256, BigDecimal for Decimal types + val dec: BigDecimal = value match { + case bi: java.math.BigInteger => BigDecimal(bi) + case bd: java.math.BigDecimal => BigDecimal(bd) + } + Decimal(dec.setScale(d.scale)) + case TimestampType => + var _instant = value.asInstanceOf[ZonedDateTime].withZoneSameInstant(ZoneOffset.UTC) + TimeUnit.SECONDS.toMicros(_instant.toEpochSecond) + TimeUnit.NANOSECONDS.toMicros(_instant.getNano()) + case StringType => + val strValue = value match { + case uuid: java.util.UUID => uuid.toString + case inet: java.net.InetAddress => inet.getHostAddress + case s: String => s + case enumValue: BinaryStreamReader.EnumValue => enumValue.toString + case _ => value.toString + } + UTF8String.fromString(strValue) + case DateType => + val localDate = value match { + case ld: LocalDate => ld + case zdt: ZonedDateTime => zdt.toLocalDate + case _ => value.asInstanceOf[LocalDate] + } + localDate.toEpochDay.toInt + case BinaryType => value.asInstanceOf[String].getBytes + case ArrayType(_dataType, _nullable) => + // Java client returns BinaryStreamReader.ArrayValue for arrays + val arrayVal = value.asInstanceOf[BinaryStreamReader.ArrayValue] + val arrayValue = arrayVal.getArrayOfObjects().toSeq.asInstanceOf[Seq[Object]] + val convertedArray = Array.tabulate(arrayValue.length) { i => + decodeValue( + arrayValue(i), + StructField("element", _dataType, _nullable) + ) + } + new GenericArrayData(convertedArray) + case MapType(_keyType, _valueType, _valueNullable) => + // Java client returns util.Map (LinkedHashMap or EmptyMap) + val javaMap = value.asInstanceOf[util.Map[Object, Object]] + val convertedMap = + javaMap.asScala.map { case (rawKey, rawValue) => + val decodedKey = decodeValue(rawKey, StructField("key", _keyType, false)) + val decodedValue = + decodeValue(rawValue, StructField("value", _valueType, _valueNullable)) + (decodedKey, decodedValue) + } + ArrayBasedMapData(convertedMap) + case _ => + throw CHClientException(s"Unsupported catalyst type ${structField.name}[${structField.dataType}]") + } + } + +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/format/ClickHouseJsonReader.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/format/ClickHouseJsonReader.scala new file mode 100644 index 00000000..f5a99695 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/read/format/ClickHouseJsonReader.scala @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
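The TimestampType branch of ClickHouseBinaryReader.decodeValue converts the client's ZonedDateTime into Catalyst's internal representation, microseconds since the epoch in UTC. A standalone illustration of that conversion (the example input is arbitrary):

import java.time.{ZoneOffset, ZonedDateTime}
import java.util.concurrent.TimeUnit

object TimestampDecodeSketch {
  def toCatalystMicros(zdt: ZonedDateTime): Long = {
    val utc = zdt.withZoneSameInstant(ZoneOffset.UTC)
    // Whole seconds and the sub-second nanos are converted to microseconds separately, then summed.
    TimeUnit.SECONDS.toMicros(utc.toEpochSecond) + TimeUnit.NANOSECONDS.toMicros(utc.getNano.toLong)
  }

  def main(args: Array[String]): Unit =
    println(toCatalystMicros(ZonedDateTime.parse("2024-01-01T12:34:56.789+03:00")))
}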
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark.read.format + +import com.clickhouse.spark.exception.CHClientException +import com.clickhouse.spark.format.{JSONCompactEachRowWithNamesAndTypesStreamOutput, StreamOutput} +import com.clickhouse.spark.read.{ClickHouseInputPartition, ClickHouseReader, ScanJobDescription} +import com.fasterxml.jackson.databind.JsonNode +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import com.clickhouse.spark.Utils.{dateFmt, dateTimeFmt} + +import java.math.{MathContext, RoundingMode => RM} +import java.time.{LocalDate, ZoneOffset, ZonedDateTime} +import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ +import scala.math.BigDecimal.RoundingMode + +class ClickHouseJsonReader( + scanJob: ScanJobDescription, + part: ClickHouseInputPartition +) extends ClickHouseReader[Array[JsonNode]](scanJob, part) { + + override val format: String = "JSONCompactEachRowWithNamesAndTypes" + + lazy val streamOutput: StreamOutput[Array[JsonNode]] = + JSONCompactEachRowWithNamesAndTypesStreamOutput.deserializeStream(resp.getInputStream) + + override def decode(record: Array[JsonNode]): InternalRow = { + val values: Array[Any] = new Array[Any](record.length) + if (readSchema.nonEmpty) { + var i: Int = 0 + while (i < record.length) { + values(i) = decodeValue(record(i), readSchema.fields(i)) + i = i + 1 + } + } + new GenericInternalRow(values) + } + + private def decodeValue(jsonNode: JsonNode, structField: StructField): Any = { + if (jsonNode == null || jsonNode.isNull) { + // should we check `structField.nullable`? 
+ return null + } + + structField.dataType match { + case BooleanType => jsonNode.asBoolean + case ByteType => jsonNode.asInt.byteValue + case ShortType => jsonNode.asInt.shortValue + case IntegerType => jsonNode.asInt + case LongType => jsonNode.asLong + case FloatType => jsonNode.asDouble.floatValue + case DoubleType => jsonNode.asDouble + case d: DecimalType if jsonNode.isBigDecimal => + Decimal(jsonNode.decimalValue, d.precision, d.scale) + case d: DecimalType if jsonNode.isFloat | jsonNode.isDouble => + Decimal(BigDecimal(jsonNode.doubleValue, new MathContext(d.precision)), d.precision, d.scale) + case d: DecimalType if jsonNode.isInt => + Decimal(BigDecimal(jsonNode.intValue, new MathContext(d.precision)), d.precision, d.scale) + case d: DecimalType if jsonNode.isLong => + Decimal(BigDecimal(jsonNode.longValue, new MathContext(d.precision)), d.precision, d.scale) + case d: DecimalType if jsonNode.isBigInteger => + Decimal(BigDecimal(jsonNode.bigIntegerValue, new MathContext(d.precision)), d.precision, d.scale) + case d: DecimalType => + Decimal(BigDecimal(jsonNode.textValue, new MathContext(d.precision)), d.precision, d.scale) + case TimestampType => + var _instant = + ZonedDateTime.parse(jsonNode.asText, dateTimeFmt.withZone(scanJob.tz)).withZoneSameInstant(ZoneOffset.UTC) + TimeUnit.SECONDS.toMicros(_instant.toEpochSecond) + TimeUnit.NANOSECONDS.toMicros(_instant.getNano()) + case StringType => UTF8String.fromString(jsonNode.asText) + case DateType => LocalDate.parse(jsonNode.asText, dateFmt).toEpochDay.toInt + case BinaryType if jsonNode.isTextual => + // ClickHouse JSON format returns FixedString as plain text, not Base64 + jsonNode.asText.getBytes("UTF-8") + case BinaryType => + // True binary data is Base64 encoded in JSON format + jsonNode.binaryValue + case ArrayType(_dataType, _nullable) => + val _structField = StructField(s"${structField.name}__array_element__", _dataType, _nullable) + new GenericArrayData(jsonNode.asScala.map(decodeValue(_, _structField)).toArray) + case MapType(StringType, _valueType, _valueNullable) => + val mapData = jsonNode.fields.asScala.map { entry => + val _structField = StructField(s"${structField.name}__map_value__", _valueType, _valueNullable) + UTF8String.fromString(entry.getKey) -> decodeValue(entry.getValue, _structField) + }.toMap + ArrayBasedMapData(mapData) + case _ => + throw CHClientException(s"Unsupported catalyst type ${structField.name}[${structField.dataType}]") + } + } +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/ClickHouseWrite.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/ClickHouseWrite.scala new file mode 100644 index 00000000..da4cc936 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/ClickHouseWrite.scala @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under th e License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
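ClickHouseJsonReader normalizes the many JSON shapes a ClickHouse Decimal can arrive in (number, big integer, or quoted text) through BigDecimal with the target precision before building a Catalyst Decimal. A standalone sketch of the textual path, assuming only spark-catalyst on the classpath; the sample values are arbitrary:

import java.math.MathContext
import org.apache.spark.sql.types.Decimal

object DecimalDecodeSketch {
  // Parse decimal text with the declared precision, then fit it to the target precision/scale.
  def decodeDecimal(raw: String, precision: Int, scale: Int): Decimal =
    Decimal(BigDecimal(raw, new MathContext(precision)), precision, scale)

  def main(args: Array[String]): Unit = {
    println(decodeDecimal("12345.6789", 18, 4))
    println(decodeDecimal("-0.0042", 9, 4))
  }
}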
+ */ + +package com.clickhouse.spark.write + +import com.clickhouse.spark.{BytesWrittenMetric, RecordsWrittenMetric, SerializeTimeMetric, WriteTimeMetric} +import com.clickhouse.spark.exception.CHClientException +import com.clickhouse.spark.write.format.{ClickHouseArrowStreamWriter, ClickHouseJsonEachRowWriter} +import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} +import org.apache.spark.sql.clickhouse.ClickHouseSQLConf._ +import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} +import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.connector.metric.CustomMetric +import org.apache.spark.sql.connector.write._ +import com.clickhouse.spark._ + +class ClickHouseWriteBuilder(writeJob: WriteJobDescription) extends WriteBuilder { + + override def build(): Write = new ClickHouseWrite(writeJob) +} + +class ClickHouseWrite( + writeJob: WriteJobDescription +) extends Write + with RequiresDistributionAndOrdering + with SQLConfHelper { + + override def distributionStrictlyRequired: Boolean = writeJob.writeOptions.repartitionStrictly + + override def description: String = + s"ClickHouseWrite(database=${writeJob.targetDatabase(false)}, table=${writeJob.targetTable(false)})})" + + override def requiredDistribution(): Distribution = Distributions.clustered(writeJob.sparkSplits.toArray) + + override def requiredNumPartitions(): Int = conf.getConf(WRITE_REPARTITION_NUM) + + override def requiredOrdering(): Array[SortOrder] = writeJob.sparkSortOrders + + override def toBatch: BatchWrite = new ClickHouseBatchWrite(writeJob) + + override def supportedCustomMetrics(): Array[CustomMetric] = Array( + RecordsWrittenMetric(), + BytesWrittenMetric(), + SerializeTimeMetric(), + WriteTimeMetric() + ) +} + +class ClickHouseBatchWrite( + writeJob: WriteJobDescription +) extends BatchWrite with DataWriterFactory { + + override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = this + + override def commit(messages: Array[WriterCommitMessage]): Unit = {} + + override def abort(messages: Array[WriterCommitMessage]): Unit = {} + + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + val format = writeJob.writeOptions.format + format match { + case "json" => new ClickHouseJsonEachRowWriter(writeJob) + case "arrow" => new ClickHouseArrowStreamWriter(writeJob) + case unsupported => throw CHClientException(s"Unsupported write format: $unsupported") + } + } +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/ClickHouseWriter.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/ClickHouseWriter.scala new file mode 100644 index 00000000..9d17feea --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/ClickHouseWriter.scala @@ -0,0 +1,301 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
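ClickHouseWrite advertises the required distribution and ordering, and ClickHouseBatchWrite picks the row serializer ("json" or "arrow") per task. From the user side these are driven by the write confs defined later in this patch; a hedged sketch, with a placeholder catalog identifier and an existing target table assumed:

import org.apache.spark.sql.SparkSession

object WriteOptionsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // Serialize batches as ArrowStream (default) or JSONEachRow.
    spark.conf.set("spark.clickhouse.write.format", "arrow")
    // Repartition to this many partitions before rows reach the writers.
    spark.conf.set("spark.clickhouse.write.repartitionNum", "8")
    // Flush a batch to ClickHouse every 10000 buffered rows.
    spark.conf.set("spark.clickhouse.write.batchSize", "10000")

    Seq((1L, "alice"), (2L, "bob")).toDF("id", "name")
      .writeTo("clickhouse.db.users") // placeholder identifier; the table must already exist
      .append()

    spark.stop()
  }
}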
+ */
+
+package com.clickhouse.spark.write
+
+import com.clickhouse.spark.{CommitMessage, Logging, TaskMetric, Utils}
+import com.clickhouse.client.ClickHouseProtocol
+import com.clickhouse.data.ClickHouseCompression
+import com.clickhouse.spark.exception.{CHClientException, RetryableCHException}
+import com.clickhouse.spark.spec.{DistributedEngineSpec, ShardUtils}
+import org.apache.commons.io.IOUtils
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, TransformExpression}
+import org.apache.spark.sql.catalyst.expressions.{Projection, SafeProjection}
+import org.apache.spark.sql.catalyst.{expressions, InternalRow}
+import org.apache.spark.sql.clickhouse.ExprUtils
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
+import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
+import org.apache.spark.sql.types._
+import com.clickhouse.spark.Metrics._
+import com.clickhouse.spark.io.{ForwardingOutputStream, ObservableOutputStream}
+import com.clickhouse.spark._
+import com.clickhouse.spark.client.{ClusterClient, NodeClient}
+import com.clickhouse.spark.exception._
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, OutputStream}
+import java.util.concurrent.atomic.LongAdder
+import scala.util.{Failure, Success}
+
+abstract class ClickHouseWriter(writeJob: WriteJobDescription)
+  extends DataWriter[InternalRow] with Logging {
+
+  val database: String = writeJob.targetDatabase(writeJob.writeOptions.convertDistributedToLocal)
+  val table: String = writeJob.targetTable(writeJob.writeOptions.convertDistributedToLocal)
+  val codec: ClickHouseCompression = writeJob.writeOptions.compressionCodec
+  val protocol: ClickHouseProtocol = writeJob.node.protocol
+
+  // ClickHouse is nullability-sensitive: if the table column is not nullable, we forcibly mark the
+  // corresponding Spark column as non-nullable.
+ protected val revisedDataSchema: StructType = StructType( + writeJob.dataSetSchema.map { field => + writeJob.tableSchema.find(_.name == field.name) match { + case Some(tableField) if !tableField.nullable && field.nullable => field.copy(nullable = false) + case _ => field + } + } + ) + + protected lazy val shardExpr: Option[Expression] = writeJob.sparkShardExpr match { + case None => None + case Some(v2Expr) => + val catalystExpr = ExprUtils.toCatalyst(v2Expr, writeJob.dataSetSchema.fields, writeJob.functionRegistry) + catalystExpr match { + case BoundReference(_, dataType, _) + if dataType.isInstanceOf[ByteType] // list all integral types here because we can not access `IntegralType` + || dataType.isInstanceOf[ShortType] + || dataType.isInstanceOf[IntegerType] + || dataType.isInstanceOf[LongType] => + Some(catalystExpr) + case BoundReference(_, dataType, _) => + throw CHClientException(s"Invalid data type of sharding field: $dataType") + case TransformExpression(function, _, _) => + function.resultType match { + case ByteType | ShortType | IntegerType | LongType => Some(catalystExpr) + case _ => throw CHClientException(s"Invalid data type of sharding field: ${function.resultType}") + } + case unsupported: Expression => + log.warn(s"Unsupported expression of sharding field: $unsupported") + None + } + } + + protected lazy val shardProjection: Option[expressions.Projection] = shardExpr + .filter(_ => writeJob.writeOptions.convertDistributedToLocal) + .flatMap { + case expr: BoundReference => + Some(SafeProjection.create(Seq(expr))) + case expr @ TransformExpression(function, _, _) => + // result type must be integer class + function.resultType match { + case ByteType => classOf[Byte] + case ShortType => classOf[Short] + case IntegerType => classOf[Int] + case LongType => classOf[Long] + case _ => throw CHClientException(s"Invalid return data type for function ${function.name()}," + + s"sharding field: ${function.resultType}") + } + Some(SafeProjection.create(Seq(ExprUtils.resolveTransformCatalyst(expr, Some(writeJob.tz.getId))))) + } + + // put the node select strategy in executor side because we need to calculate shard and don't know the records + // util DataWriter#write(InternalRow) invoked. 
+ protected lazy val client: Either[ClusterClient, NodeClient] = + writeJob.tableEngineSpec match { + case _: DistributedEngineSpec + if writeJob.writeOptions.useClusterNodesForDistributed || writeJob.writeOptions.convertDistributedToLocal => + val clusterSpec = writeJob.cluster.get + log.info(s"Connect to cluster ${clusterSpec.name}, which has ${clusterSpec.shards.length} shards and " + + s"${clusterSpec.nodes.length} nodes.") + Left(ClusterClient(clusterSpec)) + case _ => + val nodeSpec = writeJob.node + log.info(s"Connect to single node: $nodeSpec") + Right(NodeClient(nodeSpec)) + } + + def nodeClient(shardNum: Option[Int]): NodeClient = client match { + case Left(clusterClient) => clusterClient.node(shardNum) + case Right(nodeClient) => nodeClient + } + + def calcShard(record: InternalRow): Option[Int] = (shardExpr, shardProjection) match { + case (Some(BoundReference(_, dataType, _)), Some(projection)) => + doCalcShard(record, dataType, projection) + case (Some(TransformExpression(function, _, _)), Some(projection)) => + doCalcShard(record, function.resultType, projection) + case _ => None + } + + private def doCalcShard(record: InternalRow, dataType: DataType, projection: Projection): Option[Int] = { + val shardValue = dataType match { + case ByteType => Some(projection(record).getByte(0).toLong) + case ShortType => Some(projection(record).getShort(0).toLong) + case IntegerType => Some(projection(record).getInt(0).toLong) + case LongType => Some(projection(record).getLong(0)) + case _ => None + } + shardValue.map(value => ShardUtils.calcShard(writeJob.cluster.get, value).num) + } + + val _currentBufferedRows = new LongAdder + def currentBufferedRows: Long = _currentBufferedRows.longValue + val _totalRecordsWritten = new LongAdder + def totalRecordsWritten: Long = _totalRecordsWritten.longValue + val _currentRawBytesWritten = new LongAdder + def currentBufferedRawBytes: Long = _currentRawBytesWritten.longValue + val _totalRawBytesWritten = new LongAdder + def totalRawBytesWritten: Long = _totalRawBytesWritten.longValue + val _lastSerializedBytesWritten = new LongAdder + def lastSerializedBytesWritten: Long = _lastSerializedBytesWritten.longValue + val _totalSerializedBytesWritten = new LongAdder + def totalSerializedBytesWritten: Long = _totalSerializedBytesWritten.longValue + val _lastSerializeTime = new LongAdder + def lastSerializeTime: Long = _lastSerializeTime.longValue + val _totalSerializeTime = new LongAdder + def totalSerializeTime: Long = _totalSerializeTime.longValue + val _totalWriteTime = new LongAdder + def totalWriteTime: Long = _totalWriteTime.longValue + + val serializedBuffer = new ByteArrayOutputStream(64 * 1024 * 1024) + + // it is not accurate when using http protocol, because we delegate compression to + // clickhouse http client + private val observableSerializedOutput = new ObservableOutputStream( + serializedBuffer, + Some(_lastSerializedBytesWritten), + Some(_totalSerializedBytesWritten) + ) + + private val compressedForwardingOutput = new ForwardingOutputStream() + + private val observableCompressedOutput = new ObservableOutputStream( + compressedForwardingOutput, + Some(_currentRawBytesWritten), + Some(_totalRawBytesWritten), + Some(_lastSerializeTime), + Some(_totalSerializeTime) + ) + + def output: OutputStream = observableCompressedOutput + + private def renewCompressedOutput(): Unit = { + val compressedOutput = (codec, protocol) match { + case (ClickHouseCompression.NONE, _) => observableSerializedOutput + case (ClickHouseCompression.LZ4, 
ClickHouseProtocol.HTTP) => + // clickhouse http client forces compressed output stream + // new Lz4OutputStream(observableSerializedOutput, 4 * 1024 * 1024, null) + observableSerializedOutput + case unsupported => + throw CHClientException(s"unsupported compression codec: $unsupported") + } + compressedForwardingOutput.updateDelegate(compressedOutput) + } + + renewCompressedOutput() + + override def currentMetricsValues: Array[CustomTaskMetric] = Array( + TaskMetric(RECORDS_WRITTEN, totalRecordsWritten), + TaskMetric(BYTES_WRITTEN, totalSerializedBytesWritten), + TaskMetric(SERIALIZE_TIME, totalSerializeTime), + TaskMetric(WRITE_TIME, totalWriteTime) + ) + + def format: String + + var currentShardNum: Option[Int] = None + + override def write(record: InternalRow): Unit = { + val shardNum = calcShard(record) + flush(force = shardNum != currentShardNum && currentBufferedRows > 0, currentShardNum) + currentShardNum = shardNum + val (_, serializedTime) = Utils.timeTakenMs(writeRow(record)) + _lastSerializeTime.add(serializedTime) + _totalSerializeTime.add(serializedTime) + _currentBufferedRows.add(1) + flush(force = false, currentShardNum) + } + + def writeRow(record: InternalRow): Unit + + def serialize(): Array[Byte] = { + val (data, serializedTime) = Utils.timeTakenMs(doSerialize()) + _lastSerializeTime.add(serializedTime) + _totalSerializeTime.add(serializedTime) + data + } + + def doSerialize(): Array[Byte] + + def reset(): Unit = { + _currentBufferedRows.reset() + _currentRawBytesWritten.reset() + _lastSerializedBytesWritten.reset() + _lastSerializeTime.reset() + currentShardNum = None + serializedBuffer.reset() + renewCompressedOutput() + } + + def flush(force: Boolean, shardNum: Option[Int]): Unit = + if (force) { + doFlush(shardNum) + } else if (currentBufferedRows >= writeJob.writeOptions.batchSize) { + doFlush(shardNum) + } + + def doFlush(shardNum: Option[Int]): Unit = { + val client = nodeClient(shardNum) + val data = serialize() + var writeTime = 0L + Utils.retry[Unit, RetryableCHException]( + writeJob.writeOptions.maxRetry, + writeJob.writeOptions.retryInterval + ) { + var startWriteTime = System.currentTimeMillis + // codec, + client.syncInsertOutputJSONEachRow(database, table, format, new ByteArrayInputStream(data)) match { + case Right(_) => + writeTime = System.currentTimeMillis - startWriteTime + _totalWriteTime.add(writeTime) + _totalRecordsWritten.add(currentBufferedRows) + case Left(retryable) if writeJob.writeOptions.retryableErrorCodes.contains(retryable.code) => + startWriteTime = System.currentTimeMillis + throw RetryableCHException(retryable.code, retryable.reason, Some(client.nodeSpec)) + case Left(rethrow) => throw rethrow + } + } match { + case Success(_) => + log.info( + s"""Job[${writeJob.queryId}]: batch write completed + |cluster: ${writeJob.cluster.map(_.name).getOrElse("none")}, shard: ${shardNum.getOrElse("none")} + |node: ${client.nodeSpec} + | row count: $currentBufferedRows + | raw size: ${Utils.bytesToString(currentBufferedRawBytes)} + | format: $format + |compression codec: $codec + | serialized size: ${Utils.bytesToString(lastSerializedBytesWritten)} + | serialize time: ${lastSerializeTime}ms + | write time: ${writeTime}ms + |""".stripMargin + ) + reset() + case Failure(rethrow) => throw rethrow + } + } + + override def commit(): WriterCommitMessage = { + flush(currentBufferedRows > 0, currentShardNum) + CommitMessage(s"Job[${writeJob.queryId}]: commit") + } + + override def abort(): Unit = {} + + override def close(): Unit = { + 
IOUtils.closeQuietly(output) + client match { + case Left(clusterClient) => clusterClient.close() + case Right(nodeClient) => nodeClient.close() + } + } +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/WriteJobDescription.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/WriteJobDescription.scala new file mode 100644 index 00000000..ca58eb89 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/WriteJobDescription.scala @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark.write + +import com.clickhouse.spark.expr.{Expr, FuncExpr, OrderExpr} +import com.clickhouse.spark.func.FunctionRegistry +import com.clickhouse.spark.spec.{ClusterSpec, DistributedEngineSpec, NodeSpec, TableEngineSpec, TableSpec} + +import java.time.ZoneId +import org.apache.spark.sql.clickhouse.{ExprUtils, WriteOptions} +import org.apache.spark.sql.connector.expressions.{Expression, SortOrder, Transform} +import org.apache.spark.sql.types.StructType +import com.clickhouse.spark.spec._ + +case class WriteJobDescription( + queryId: String, + tableSchema: StructType, + metadataSchema: StructType, + dataSetSchema: StructType, + node: NodeSpec, + tz: ZoneId, + tableSpec: TableSpec, + tableEngineSpec: TableEngineSpec, + cluster: Option[ClusterSpec], + localTableSpec: Option[TableSpec], + localTableEngineSpec: Option[TableEngineSpec], + shardingKey: Option[Expr], + partitionKey: Option[List[Expr]], + sortingKey: Option[List[OrderExpr]], + writeOptions: WriteOptions, + functionRegistry: FunctionRegistry +) { + + def targetDatabase(convert2Local: Boolean): String = tableEngineSpec match { + case dist: DistributedEngineSpec if convert2Local => dist.local_db + case _ => tableSpec.database + } + + def targetTable(convert2Local: Boolean): String = tableEngineSpec match { + case dist: DistributedEngineSpec if convert2Local => dist.local_table + case _ => tableSpec.name + } + + def shardingKeyIgnoreRand: Option[Expr] = shardingKey filter { + case FuncExpr("rand", Nil) => false + case _ => true + } + + def sparkShardExpr: Option[Expression] = shardingKeyIgnoreRand match { + case Some(expr) => ExprUtils.toSparkTransformOpt(expr, functionRegistry) + case _ => None + } + + def sparkSplits: Array[Transform] = + if (writeOptions.repartitionByPartition) { + ExprUtils.toSparkSplits( + shardingKeyIgnoreRand, + partitionKey, + functionRegistry + ) + } else { + ExprUtils.toSparkSplits( + shardingKeyIgnoreRand, + None, + functionRegistry + ) + } + + def sparkSortOrders: Array[SortOrder] = { + val _partitionKey = if (writeOptions.localSortByPartition) partitionKey else None + val _sortingKey = if (writeOptions.localSortByKey) sortingKey else None + ExprUtils.toSparkSortOrders(shardingKeyIgnoreRand, _partitionKey, _sortingKey, cluster, functionRegistry) + } +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/format/ClickHouseArrowStreamWriter.scala 
b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/format/ClickHouseArrowStreamWriter.scala new file mode 100644 index 00000000..b538c489 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/format/ClickHouseArrowStreamWriter.scala @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under th e License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.clickhouse.spark.write.format + +import com.clickhouse.spark.write.{ClickHouseWriter, WriteJobDescription} +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.clickhouse.SparkUtils +import org.apache.spark.sql.execution.arrow.ArrowWriter + +class ClickHouseArrowStreamWriter(writeJob: WriteJobDescription) extends ClickHouseWriter(writeJob) { + + override def format: String = "ArrowStream" + + val allocator: BufferAllocator = SparkUtils.spawnArrowAllocator("writer for ClickHouse") + val arrowSchema: Schema = SparkUtils.toArrowSchema(revisedDataSchema, writeJob.tz.getId) + val root: VectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter: ArrowWriter = ArrowWriter.create(root) + + override def writeRow(record: InternalRow): Unit = arrowWriter.write(record) + + override def doSerialize(): Array[Byte] = { + arrowWriter.finish() + val arrowStreamWriter = new ArrowStreamWriter(root, null, output) + arrowStreamWriter.writeBatch() + arrowStreamWriter.end() + output.flush() + output.close() + serializedBuffer.toByteArray + } + + override def reset(): Unit = { + super.reset() + arrowWriter.reset() + } + + override def close(): Unit = { + root.close() + allocator.close() + super.close() + } +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/format/ClickHouseJsonEachRowWriter.scala b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/format/ClickHouseJsonEachRowWriter.scala new file mode 100644 index 00000000..756c7d87 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/com/clickhouse/spark/write/format/ClickHouseJsonEachRowWriter.scala @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under th e License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
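ClickHouseWriter.doFlush (earlier in this patch) retries a failed batch only when the server error code is listed in spark.clickhouse.write.retryableErrorCodes, sleeping between attempts. A standalone sketch of that retry shape, with a made-up ServerError type standing in for the connector's retryable exception:

import scala.util.{Failure, Try}

object RetrySketch {
  final case class ServerError(code: Int, reason: String) extends RuntimeException(reason)

  // Retry `action` up to `maxRetry` extra times, but only for error codes marked retryable.
  def retry[T](maxRetry: Int, intervalMs: Long, retryableCodes: Set[Int])(action: => T): Try[T] = {
    def loop(remaining: Int): Try[T] = Try(action) match {
      case Failure(e: ServerError) if retryableCodes.contains(e.code) && remaining > 0 =>
        Thread.sleep(intervalMs)
        loop(remaining - 1)
      case other => other
    }
    loop(maxRetry)
  }

  def main(args: Array[String]): Unit = {
    var attempts = 0
    // 241 is used here as an example of a retryable code (MEMORY_LIMIT_EXCEEDED).
    val result = retry(maxRetry = 3, intervalMs = 10, retryableCodes = Set(241)) {
      attempts += 1
      if (attempts < 3) throw ServerError(241, "memory limit exceeded") else "ok"
    }
    println(s"result=$result after $attempts attempts")
  }
}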
+ */ + +package com.clickhouse.spark.write.format + +import com.clickhouse.spark.write.{ClickHouseWriter, WriteJobDescription} +import org.apache.commons.io.IOUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.clickhouse.JsonWriter + +class ClickHouseJsonEachRowWriter(writeJob: WriteJobDescription) extends ClickHouseWriter(writeJob) { + + override def format: String = "JSONEachRow" + + val jsonWriter: JsonWriter = new JsonWriter(revisedDataSchema, writeJob.tz, output) + + override def writeRow(record: InternalRow): Unit = jsonWriter.write(record) + + override def doSerialize(): Array[Byte] = { + jsonWriter.flush() + output.close() + serializedBuffer.toByteArray + } + + override def close(): Unit = { + IOUtils.closeQuietly(jsonWriter) + super.close() + } +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ClickHouseSQLConf.scala b/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ClickHouseSQLConf.scala new file mode 100644 index 00000000..39e2bc4a --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ClickHouseSQLConf.scala @@ -0,0 +1,221 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse + +import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry} +import org.apache.spark.sql.internal.SQLConf._ +import com.clickhouse.spark.exception.ClickHouseErrCode._ + +import java.util.concurrent.TimeUnit + +/** + * Run the following command to update the configuration docs. + * UPDATE=1 ./gradlew test --tests=ConfigurationSuite + */ +object ClickHouseSQLConf { + + val WRITE_BATCH_SIZE: ConfigEntry[Int] = + buildConf("spark.clickhouse.write.batchSize") + .doc("The number of records per batch on writing to ClickHouse.") + .version("0.1.0") + .intConf + .checkValue(v => v > 0, "`spark.clickhouse.write.batchSize` should be positive.") + .createWithDefault(10000) + + val WRITE_MAX_RETRY: ConfigEntry[Int] = + buildConf("spark.clickhouse.write.maxRetry") + .doc("The maximum number of write we will retry for a single batch write failed with retryable codes.") + .version("0.1.0") + .intConf + .checkValue(_ >= 0, "Should be 0 or positive. 
0 means disable retry.") + .createWithDefault(3) + + val WRITE_RETRY_INTERVAL: ConfigEntry[Long] = + buildConf("spark.clickhouse.write.retryInterval") + .doc("The interval in seconds between write retry.") + .version("0.1.0") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("10s") + + val WRITE_RETRYABLE_ERROR_CODES: ConfigEntry[Seq[Int]] = + buildConf("spark.clickhouse.write.retryableErrorCodes") + .doc("The retryable error codes returned by ClickHouse server when write failing.") + .version("0.1.0") + .intConf + .toSequence + .checkValue(codes => !codes.exists(_ <= OK.code), "Error code should be positive.") + .createWithDefault(MEMORY_LIMIT_EXCEEDED.code :: Nil) + + val WRITE_REPARTITION_NUM: ConfigEntry[Int] = + buildConf("spark.clickhouse.write.repartitionNum") + .doc("Repartition data to meet the distributions of ClickHouse table is required before writing, " + + "use this conf to specific the repartition number, value less than 1 mean no requirement.") + .version("0.1.0") + .intConf + .createWithDefault(0) + + val WRITE_REPARTITION_BY_PARTITION: ConfigEntry[Boolean] = + buildConf("spark.clickhouse.write.repartitionByPartition") + .doc("Whether to repartition data by ClickHouse partition keys to meet the distributions of " + + "ClickHouse table before writing.") + .version("0.3.0") + .booleanConf + .createWithDefault(true) + + val WRITE_REPARTITION_STRICTLY: ConfigEntry[Boolean] = + buildConf("spark.clickhouse.write.repartitionStrictly") + .doc("If `true`, Spark will strictly distribute incoming records across partitions to satisfy " + + "the required distribution before passing the records to the data source table on write. " + + "Otherwise, Spark may apply certain optimizations to speed up the query but break the " + + "distribution requirement. Note, this configuration requires SPARK-37523(available in " + + "Spark 3.4), w/o this patch, it always acts as `true`.") + .version("0.3.0") + .booleanConf + .createWithDefault(false) + + val WRITE_DISTRIBUTED_USE_CLUSTER_NODES: ConfigEntry[Boolean] = + buildConf("spark.clickhouse.write.distributed.useClusterNodes") + .doc("Write to all nodes of cluster when writing Distributed table.") + .version("0.1.0") + .booleanConf + .createWithDefault(true) + + val READ_DISTRIBUTED_USE_CLUSTER_NODES: ConfigEntry[Boolean] = + buildConf("spark.clickhouse.read.distributed.useClusterNodes") + .doc("Read from all nodes of cluster when reading Distributed table.") + .internal + .version("0.1.0") + .booleanConf + .checkValue(_ == false, s"`spark.clickhouse.read.distributed.useClusterNodes` is not support yet.") + .createWithDefault(false) + + val WRITE_DISTRIBUTED_CONVERT_LOCAL: ConfigEntry[Boolean] = + buildConf("spark.clickhouse.write.distributed.convertLocal") + .doc("When writing Distributed table, write local table instead of itself. " + + "If `true`, ignore `spark.clickhouse.write.distributed.useClusterNodes`.") + .version("0.1.0") + .booleanConf + .createWithDefault(false) + + val READ_DISTRIBUTED_CONVERT_LOCAL: ConfigEntry[Boolean] = + buildConf("spark.clickhouse.read.distributed.convertLocal") + .doc("When reading Distributed table, read local table instead of itself. " + + s"If `true`, ignore `${READ_DISTRIBUTED_USE_CLUSTER_NODES.key}`.") + .version("0.1.0") + .booleanConf + .createWithDefault(true) + + val READ_SPLIT_BY_PARTITION_ID: ConfigEntry[Boolean] = + buildConf("spark.clickhouse.read.splitByPartitionId") + .doc("If `true`, construct input partition filter by virtual column `_partition_id`, " + + "instead of partition value. 
There are known bugs to assemble SQL predication by " + + "partition value. This feature requires ClickHouse Server v21.6+") + .version("0.4.0") + .booleanConf + .createWithDefault(true) + + val WRITE_LOCAL_SORT_BY_PARTITION: ConfigEntry[Boolean] = + buildConf("spark.clickhouse.write.localSortByPartition") + .doc(s"If `true`, do local sort by partition before writing. If not set, it equals to " + + s"`${WRITE_REPARTITION_BY_PARTITION.key}`.") + .version("0.3.0") + .fallbackConf(WRITE_REPARTITION_BY_PARTITION) + + val WRITE_LOCAL_SORT_BY_KEY: ConfigEntry[Boolean] = + buildConf("spark.clickhouse.write.localSortByKey") + .doc("If `true`, do local sort by sort keys before writing.") + .version("0.3.0") + .booleanConf + .createWithDefault(true) + + val IGNORE_UNSUPPORTED_TRANSFORM: ConfigEntry[Boolean] = + buildConf("spark.clickhouse.ignoreUnsupportedTransform") + .doc("ClickHouse supports using complex expressions as sharding keys or partition values, " + + "e.g. `cityHash64(col_1, col_2)`, and those can not be supported by Spark now. If `true`, " + + "ignore the unsupported expressions, otherwise fail fast w/ an exception. Note, when " + + s"`${WRITE_DISTRIBUTED_CONVERT_LOCAL.key}` is enabled, ignore unsupported sharding keys " + + "may corrupt the data.") + .version("0.4.0") + .booleanConf + .createWithDefault(false) + + val READ_COMPRESSION_CODEC: ConfigEntry[String] = + buildConf("spark.clickhouse.read.compression.codec") + .doc("The codec used to decompress data for reading. Supported codecs: none, lz4.") + .version("0.5.0") + .stringConf + .createWithDefault("lz4") + + val WRITE_COMPRESSION_CODEC: ConfigEntry[String] = + buildConf("spark.clickhouse.write.compression.codec") + .doc("The codec used to compress data for writing. Supported codecs: none, lz4.") + .version("0.3.0") + .stringConf + .createWithDefault("lz4") + + val READ_FORMAT: ConfigEntry[String] = + buildConf("spark.clickhouse.read.format") + .doc("Serialize format for reading. Supported formats: json, binary") + .version("0.6.0") + .stringConf + .transform(_.toLowerCase) + .createWithDefault("json") + + val RUNTIME_FILTER_ENABLED: ConfigEntry[Boolean] = + buildConf("spark.clickhouse.read.runtimeFilter.enabled") + .doc("Enable runtime filter for reading.") + .version("0.8.0") + .booleanConf + .createWithDefault(false) + + val WRITE_FORMAT: ConfigEntry[String] = + buildConf("spark.clickhouse.write.format") + .doc("Serialize format for writing. Supported formats: json, arrow") + .version("0.4.0") + .stringConf + .transform { + case s if s equalsIgnoreCase "JSONEachRow" => "json" + case s if s equalsIgnoreCase "ArrowStream" => "arrow" + case s => s.toLowerCase + } + .createWithDefault("arrow") + + val USE_NULLABLE_QUERY_SCHEMA: ConfigEntry[Boolean] = + buildConf("spark.clickhouse.useNullableQuerySchema") + .doc("If `true`, mark all the fields of the query schema as nullable when executing " + + "`CREATE/REPLACE TABLE ... AS SELECT ...` on creating the table. Note, this " + + "configuration requires SPARK-43390(available in Spark 3.5), w/o this patch, " + + "it always acts as `true`.") + .version("0.8.0") + .booleanConf + .createWithDefault(false) + + val READ_FIXED_STRING_AS: ConfigEntry[String] = + buildConf("spark.clickhouse.read.fixedStringAs") + .doc("Read ClickHouse FixedString type as the specified Spark data type. 
Supported types: binary, string") + .version("0.8.0") + .stringConf + .transform(_.toLowerCase) + .createWithDefault("binary") + + val READ_SETTINGS: OptionalConfigEntry[String] = + buildConf("spark.clickhouse.read.settings") + .doc("Settings when read from ClickHouse. e.g. `final=1, max_execution_time=5`") + .version("0.9.0") + .stringConf + .transform(_.toLowerCase) + .createOptional + +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ExprUtils.scala b/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ExprUtils.scala new file mode 100644 index 00000000..d760cdc1 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ExprUtils.scala @@ -0,0 +1,223 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse + +import com.clickhouse.spark.exception.CHClientException +import com.clickhouse.spark.expr.{Expr, FieldRef, FuncExpr, OrderExpr, SQLExpr, StringLiteral} +import com.clickhouse.spark.func.FunctionRegistry +import com.clickhouse.spark.spec.ClusterSpec +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, NoSuchFunctionException, TypeCoercion} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, ListQuery, Literal} +import org.apache.spark.sql.catalyst.expressions.{TimeZoneAwareExpression, TransformExpression, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.catalyst.trees.TreePattern.{LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION} +import org.apache.spark.sql.catalyst.{expressions, SQLConfHelper} +import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.IGNORE_UNSUPPORTED_TRANSFORM +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction} +import org.apache.spark.sql.connector.expressions.Expressions._ +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, SortOrder => V2SortOrder, _} +import org.apache.spark.sql.types.{StructField, StructType} +import com.clickhouse.spark.expr._ + +import scala.util.{Failure, Success, Try} + +object ExprUtils extends SQLConfHelper with Serializable { + + def toSparkPartitions( + partitionKey: Option[List[Expr]], + functionRegistry: FunctionRegistry + ): Array[Transform] = + partitionKey.seq.flatten.flatten(toSparkTransformOpt(_, functionRegistry)).toArray + + def toSparkSplits( + shardingKey: Option[Expr], + partitionKey: Option[List[Expr]], + functionRegistry: FunctionRegistry + ): Array[Transform] = + (shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt(_, functionRegistry)).toArray + + def toSparkSortOrders( + shardingKeyIgnoreRand: Option[Expr], + partitionKey: Option[List[Expr]], + sortingKey: Option[List[OrderExpr]], + 
cluster: Option[ClusterSpec], + functionRegistry: FunctionRegistry + ): Array[V2SortOrder] = + toSparkSplits( + shardingKeyIgnoreRand, + partitionKey, + functionRegistry + ).map(Expressions.sort(_, SortDirection.ASCENDING)) ++: + sortingKey.seq.flatten.flatten { case OrderExpr(expr, asc, nullFirst) => + val direction = if (asc) SortDirection.ASCENDING else SortDirection.DESCENDING + val nullOrder = if (nullFirst) NullOrdering.NULLS_FIRST else NullOrdering.NULLS_LAST + toSparkTransformOpt(expr, functionRegistry).map(trans => + Expressions.sort(trans, direction, nullOrder) + ) + }.toArray + + private def loadV2FunctionOpt( + name: String, + args: Seq[Expression], + functionRegistry: FunctionRegistry + ): Option[BoundFunction] = { + def loadFunction(ident: Identifier): UnboundFunction = + functionRegistry.load(ident.name).getOrElse(throw new NoSuchFunctionException(ident)) + val inputType = StructType(args.zipWithIndex.map { + case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable) + }) + try { + val unbound = loadFunction(Identifier.of(Array.empty, name)) + Some(unbound.bind(inputType)) + } catch { + case e: NoSuchFunctionException => + throw e + case _: UnsupportedOperationException if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => + None + case e: UnsupportedOperationException => + throw new AnalysisException( + errorClass = "UNSUPPORTED_OPERATION", + messageParameters = Map("operation" -> e.getMessage), + cause = Some(e) + ) + } + } + + def resolveTransformCatalyst( + catalystExpr: Expression, + timeZoneId: Option[String] = None + ): Expression = + new TypeCoercionExecutor(timeZoneId) + .execute(DummyLeafNode(resolveTransformExpression(catalystExpr))) + .asInstanceOf[DummyLeafNode].expr + + private case class DummyLeafNode(expr: Expression) extends LeafNode { + override def output: Seq[Attribute] = Nil + } + + private class CustomResolveTimeZone(timeZoneId: Option[String]) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = + plan.resolveExpressionsWithPruning(_.containsAnyPattern(LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION)) { + case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => + e.withTimeZone(timeZoneId.getOrElse(conf.sessionLocalTimeZone)) + // Casts could be added in the subquery plan through the rule TypeCoercion while coercing + // the types between the value expression and list query expression of IN expression. + // We need to subject the subquery plan through ResolveTimeZone again to setup timezone + // information for time zone aware expressions. 
+ case e: ListQuery => e.withNewPlan(apply(e.plan)) + } + } + + private class TypeCoercionExecutor(timeZoneId: Option[String]) extends RuleExecutor[LogicalPlan] { + override val batches = + Batch("Resolve TypeCoercion", FixedPoint(1), typeCoercionRules: _*) :: + Batch("Resolve TimeZone", FixedPoint(1), new CustomResolveTimeZone(timeZoneId)) :: Nil + } + + private def resolveTransformExpression(expr: Expression): Expression = expr.transform { + case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) => + V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments) + case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) => + V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments) + } + + private def typeCoercionRules: List[Rule[LogicalPlan]] = if (conf.ansiEnabled) { + AnsiTypeCoercion.typeCoercionRules + } else { + TypeCoercion.typeCoercionRules + } + + def toCatalyst( + v2Expr: V2Expression, + fields: Array[StructField], + functionRegistry: FunctionRegistry + ): Expression = + v2Expr match { + case IdentityTransform(ref) => toCatalyst(ref, fields, functionRegistry) + case ref: NamedReference if ref.fieldNames.length == 1 => + val (field, ordinal) = fields + .zipWithIndex + .find { case (field, _) => field.name == ref.fieldNames.head } + .getOrElse(throw CHClientException(s"Invalid field reference: $ref")) + BoundReference(ordinal, field.dataType, field.nullable) + case t: Transform => + val catalystArgs = t.arguments().map(toCatalyst(_, fields, functionRegistry)) + loadV2FunctionOpt(t.name(), catalystArgs, functionRegistry) + .map(bound => TransformExpression(bound, catalystArgs)).getOrElse { + throw CHClientException(s"Unsupported expression: $v2Expr") + } + case literal: LiteralValue[Any] => expressions.Literal(literal.value) + case _ => throw CHClientException( + s"Unsupported expression: $v2Expr" + ) + } + + def toSparkTransformOpt(expr: Expr, functionRegistry: FunctionRegistry): Option[Transform] = + Try(toSparkExpression(expr, functionRegistry)) match { + // need this function because spark `Table`'s `partitioning` field should be `Transform` + case Success(t: Transform) => Some(t) + case Success(_) => None + case Failure(_) if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => None + case Failure(rethrow) => throw new AnalysisException( + errorClass = "UNSUPPORTED_FEATURE.TRANSFORM_EXPRESSION", + messageParameters = Map("transform" -> rethrow.getMessage), + cause = Some(rethrow) + ) + } + + def toSparkExpression(expr: Expr, functionRegistry: FunctionRegistry): V2Expression = + expr match { + case FieldRef(col) => identity(col) + case StringLiteral(value) => literal(value) // TODO LiteralTransform + case FuncExpr("rand", Nil) => apply("rand") + case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col) + case FuncExpr(funName, args) if functionRegistry.clickHouseToSparkFunc.contains(funName) => + apply(functionRegistry.clickHouseToSparkFunc(funName), args.map(toSparkExpression(_, functionRegistry)): _*) + case unsupported => throw CHClientException(s"Unsupported ClickHouse expression: $unsupported") + } + + def toClickHouse( + transform: Transform, + functionRegistry: FunctionRegistry + ): Expr = transform match { + case IdentityTransform(fieldRefs) => FieldRef(fieldRefs.describe) + case ApplyTransform(name, args) if functionRegistry.sparkToClickHouseFunc.contains(name) => + FuncExpr(functionRegistry.sparkToClickHouseFunc(name), args.map(arg => 
SQLExpr(arg.describe)).toList) + case bucket: BucketTransform => throw CHClientException(s"Bucket transform not supported yet: $bucket") + case other: Transform => throw CHClientException(s"Unsupported transform: $other") + } + + def inferTransformSchema( + primarySchema: StructType, + secondarySchema: StructType, + transform: Transform, + functionRegistry: FunctionRegistry + ): StructField = transform match { + case IdentityTransform(FieldReference(Seq(col))) => primarySchema.find(_.name == col) + .orElse(secondarySchema.find(_.name == col)) + .getOrElse(throw CHClientException(s"Invalid partition column: $col")) + case t @ ApplyTransform(transformName, _) if functionRegistry.load(transformName).isDefined => + val resType = functionRegistry.load(transformName) match { + case Some(f: ScalarFunction[_]) => f.resultType + case other => throw CHClientException(s"Unsupported function: $other") + } + StructField(t.toString, resType) + case bucket: BucketTransform => throw CHClientException(s"Bucket transform not supported yet: $bucket") + case other: Transform => throw CHClientException(s"Unsupported transform: $other") + } +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/JsonWriter.scala b/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/JsonWriter.scala new file mode 100644 index 00000000..8f46a67f --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/JsonWriter.scala @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.json.{JSONOptions, JacksonGenerator} +import org.apache.spark.sql.types.StructType + +import java.io.{Closeable, Flushable, OutputStream, OutputStreamWriter} +import java.nio.charset.StandardCharsets +import java.time.ZoneId + +class JsonWriter(schema: StructType, tz: ZoneId, output: OutputStream) extends Closeable with Flushable { + private val option: Map[String, String] = Map( + "timestampFormat" -> "yyyy-MM-dd HH:mm:ss", + "timestampNTZFormat" -> "yyyy-MM-dd HH:mm:ss" + ) + private val utf8Writer = new OutputStreamWriter(output, StandardCharsets.UTF_8) + private val jsonWriter = new JacksonGenerator(schema, utf8Writer, new JSONOptions(option, tz.getId)) + + def write(row: InternalRow): Unit = { + jsonWriter.write(row) + jsonWriter.writeLineEnding() + } + + override def flush(): Unit = jsonWriter.flush() + + override def close(): Unit = jsonWriter.close() +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/SchemaUtils.scala b/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/SchemaUtils.scala new file mode 100644 index 00000000..7b43ee33 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/SchemaUtils.scala @@ -0,0 +1,123 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse + +import com.clickhouse.data.ClickHouseDataType._ +import com.clickhouse.data.{ClickHouseColumn, ClickHouseDataType} +import com.clickhouse.spark.exception.CHClientException +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.READ_FIXED_STRING_AS + +object SchemaUtils extends SQLConfHelper { + + def fromClickHouseType(chColumn: ClickHouseColumn): (DataType, Boolean) = { + val catalystType = chColumn.getDataType match { + case Nothing => NullType + case Bool => BooleanType + case String | JSON | UUID | Enum8 | Enum16 | IPv4 | IPv6 => StringType + case FixedString => + conf.getConf(READ_FIXED_STRING_AS) match { + case "binary" => BinaryType + case "string" => StringType + case unsupported => throw CHClientException(s"Unsupported fixed string read format mapping: $unsupported") + } + case Int8 => ByteType + case UInt8 | Int16 => ShortType + case UInt16 | Int32 => IntegerType + case UInt32 | Int64 | UInt64 => LongType + case Int128 | UInt128 | Int256 | UInt256 => DecimalType(38, 0) + case Float32 => FloatType + case Float64 => DoubleType + case Date | Date32 => DateType + case DateTime | DateTime32 | DateTime64 => TimestampType + case ClickHouseDataType.Decimal if chColumn.getScale <= 38 => + DecimalType(chColumn.getPrecision, chColumn.getScale) + case Decimal32 => DecimalType(9, chColumn.getScale) + case Decimal64 => DecimalType(18, chColumn.getScale) + case Decimal128 => DecimalType(38, chColumn.getScale) + case IntervalYear => YearMonthIntervalType(YearMonthIntervalType.YEAR) + case IntervalMonth => YearMonthIntervalType(YearMonthIntervalType.MONTH) + case IntervalDay => DayTimeIntervalType(DayTimeIntervalType.DAY) + case IntervalHour => DayTimeIntervalType(DayTimeIntervalType.HOUR) + case IntervalMinute => DayTimeIntervalType(DayTimeIntervalType.MINUTE) + case IntervalSecond => DayTimeIntervalType(DayTimeIntervalType.SECOND) + case Array => + val elementChCols = chColumn.getNestedColumns + assert(elementChCols.size == 1) + val (elementType, elementNullable) = fromClickHouseType(elementChCols.get(0)) + ArrayType(elementType, elementNullable) + case Map => + val kvChCols = chColumn.getNestedColumns + assert(kvChCols.size == 2) + val (keyChType, valueChType) = (kvChCols.get(0), kvChCols.get(1)) + val (keyType, keyNullable) = fromClickHouseType(keyChType) + require( + !keyNullable, + s"Illegal type: ${keyChType.getOriginalTypeName}, the key type of Map should not be nullable" + ) + val (valueType, valueNullable) = fromClickHouseType(valueChType) + MapType(keyType, valueType, valueNullable) + case Object | Nested | Tuple | Point | Polygon | MultiPolygon | Ring | IntervalQuarter | IntervalWeek | + Decimal256 | AggregateFunction | SimpleAggregateFunction => + throw CHClientException(s"Unsupported type: ${chColumn.getOriginalTypeName}") + } + (catalystType, chColumn.isNullable) + } + + def toClickHouseType(catalystType: DataType, nullable: Boolean): String = + catalystType match { + case BooleanType => maybeNullable("UInt8", nullable) + case ByteType => maybeNullable("Int8", nullable) + case ShortType => maybeNullable("Int16", nullable) + case IntegerType => maybeNullable("Int32", nullable) + case LongType => maybeNullable("Int64", nullable) + case FloatType => maybeNullable("Float32", nullable) + case DoubleType => maybeNullable("Float64", nullable) + case StringType => maybeNullable("String", nullable) + case VarcharType(_) => maybeNullable("String", 
nullable) + case CharType(_) => maybeNullable("String", nullable) // TODO: maybe FixString? + case DateType => maybeNullable("Date", nullable) + case TimestampType => maybeNullable("DateTime", nullable) + case DecimalType.Fixed(p, s) => maybeNullable(s"Decimal($p, $s)", nullable) + case ArrayType(elemType, containsNull) => s"Array(${toClickHouseType(elemType, containsNull)})" + // TODO currently only support String as key + case MapType(keyType, valueType, valueContainsNull) if keyType.isInstanceOf[StringType] => + s"Map(${toClickHouseType(keyType, nullable = false)}, ${toClickHouseType(valueType, valueContainsNull)})" + case _ => throw CHClientException(s"Unsupported type: $catalystType") + } + + def fromClickHouseSchema(chSchema: Seq[(String, String)]): StructType = { + val structFields = chSchema.map { case (name, maybeNullableType) => + val chCols = ClickHouseColumn.parse(s"`$name` $maybeNullableType") + assert(chCols.size == 1) + val (sparkType, nullable) = fromClickHouseType(chCols.get(0)) + StructField(name, sparkType, nullable) + } + StructType(structFields) + } + + def toClickHouseSchema(catalystSchema: StructType): Seq[(String, String, String)] = + catalystSchema.fields + .map { field => + val chType = toClickHouseType(field.dataType, field.nullable) + (field.name, chType, field.getComment().map(c => s" COMMENT '$c'").getOrElse("")) + } + + private[clickhouse] def maybeNullable(chType: String, nullable: Boolean): String = + if (nullable) wrapNullable(chType) else chType + + private[clickhouse] def wrapNullable(chType: String): String = s"Nullable($chType)" +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/SparkOptions.scala b/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/SparkOptions.scala new file mode 100644 index 00000000..b473d7db --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/SparkOptions.scala @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse + +import com.clickhouse.data.ClickHouseCompression +import org.apache.spark.internal.config.ConfigEntry +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.clickhouse.ClickHouseSQLConf._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.time.Duration +import java.util.{Map => JMap} + +trait SparkOptions extends SQLConfHelper with Serializable { + protected def options: CaseInsensitiveStringMap + + protected def eval[T](key: String, entry: ConfigEntry[T]): T = + Option(options.get(key)).map(entry.valueConverter).getOrElse(conf.getConf(entry)) +} + +class ReadOptions(_options: JMap[String, String]) extends SparkOptions { + + override protected def options: CaseInsensitiveStringMap = new CaseInsensitiveStringMap(_options) + + def useClusterNodesForDistributed: Boolean = + eval(READ_DISTRIBUTED_USE_CLUSTER_NODES.key, READ_DISTRIBUTED_USE_CLUSTER_NODES) + + def convertDistributedToLocal: Boolean = + eval(READ_DISTRIBUTED_CONVERT_LOCAL.key, READ_DISTRIBUTED_CONVERT_LOCAL) + + def splitByPartitionId: Boolean = + eval(READ_SPLIT_BY_PARTITION_ID.key, READ_SPLIT_BY_PARTITION_ID) + + def compressionCodec: ClickHouseCompression = + ClickHouseCompression.fromEncoding(eval(READ_COMPRESSION_CODEC.key, READ_COMPRESSION_CODEC)) + + def format: String = + eval(READ_FORMAT.key, READ_FORMAT) + + def runtimeFilterEnabled: Boolean = + eval(RUNTIME_FILTER_ENABLED.key, RUNTIME_FILTER_ENABLED) +} + +class WriteOptions(_options: JMap[String, String]) extends SparkOptions { + + override protected def options: CaseInsensitiveStringMap = new CaseInsensitiveStringMap(_options) + + def batchSize: Int = eval(WRITE_BATCH_SIZE.key, WRITE_BATCH_SIZE) + + def maxRetry: Int = eval(WRITE_MAX_RETRY.key, WRITE_MAX_RETRY) + + def retryInterval: Duration = + Duration.ofSeconds(eval(WRITE_RETRY_INTERVAL.key, WRITE_RETRY_INTERVAL)) + + def retryableErrorCodes: Seq[Int] = eval(WRITE_RETRYABLE_ERROR_CODES.key, WRITE_RETRYABLE_ERROR_CODES) + + def repartitionNum: Int = eval(WRITE_REPARTITION_NUM.key, WRITE_REPARTITION_NUM) + + def repartitionByPartition: Boolean = + eval(WRITE_REPARTITION_BY_PARTITION.key, WRITE_REPARTITION_BY_PARTITION) + + def repartitionStrictly: Boolean = + eval(WRITE_REPARTITION_STRICTLY.key, WRITE_REPARTITION_STRICTLY) + + def useClusterNodesForDistributed: Boolean = + eval(WRITE_DISTRIBUTED_USE_CLUSTER_NODES.key, WRITE_DISTRIBUTED_USE_CLUSTER_NODES) + + def convertDistributedToLocal: Boolean = + eval(WRITE_DISTRIBUTED_CONVERT_LOCAL.key, WRITE_DISTRIBUTED_CONVERT_LOCAL) + + def localSortByPartition: Boolean = + eval(WRITE_LOCAL_SORT_BY_PARTITION.key, WRITE_LOCAL_SORT_BY_PARTITION) + + def localSortByKey: Boolean = + eval(WRITE_LOCAL_SORT_BY_KEY.key, WRITE_LOCAL_SORT_BY_KEY) + + def compressionCodec: ClickHouseCompression = + ClickHouseCompression.fromEncoding(eval(WRITE_COMPRESSION_CODEC.key, WRITE_COMPRESSION_CODEC)) + + def format: String = + eval(WRITE_FORMAT.key, WRITE_FORMAT) +} diff --git a/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/SparkUtils.scala b/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/SparkUtils.scala new file mode 100644 index 00000000..3cefb0e1 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/SparkUtils.scala @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.VersionUtils + +object SparkUtils { + + lazy val MAJOR_MINOR_VERSION: (Int, Int) = VersionUtils.majorMinorVersion(SPARK_VERSION) + + def toArrowSchema(schema: StructType, timeZoneId: String): Schema = + ArrowUtils.toArrowSchema(schema, timeZoneId, true, false) + + def spawnArrowAllocator(name: String): BufferAllocator = + ArrowUtils.rootAllocator.newChildAllocator(name, 0, Long.MaxValue) +} diff --git a/spark-4.0/clickhouse-spark/src/test/resources/log4j2.xml b/spark-4.0/clickhouse-spark/src/test/resources/log4j2.xml new file mode 100644 index 00000000..3e2579f1 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/test/resources/log4j2.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/spark-4.0/clickhouse-spark/src/test/scala/org/apache/spark/sql/clickhouse/ClickHouseHelperSuite.scala b/spark-4.0/clickhouse-spark/src/test/scala/org/apache/spark/sql/clickhouse/ClickHouseHelperSuite.scala new file mode 100644 index 00000000..063b500a --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/test/scala/org/apache/spark/sql/clickhouse/ClickHouseHelperSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse + +import com.clickhouse.spark.ClickHouseHelper +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.scalatest.funsuite.AnyFunSuite + +import scala.collection.JavaConverters._ + +class ClickHouseHelperSuite extends AnyFunSuite with ClickHouseHelper { + + test("buildNodeSpec") { + val nodeSpec = buildNodeSpec( + new CaseInsensitiveStringMap(Map( + "database" -> "testing", + "option.database" -> "production", + "option.ssl" -> "true" + ).asJava) + ) + assert(nodeSpec.database === "testing") + assert(nodeSpec.options.get("ssl") === "true") + } +} diff --git a/spark-4.0/clickhouse-spark/src/test/scala/org/apache/spark/sql/clickhouse/ConfigurationSuite.scala b/spark-4.0/clickhouse-spark/src/test/scala/org/apache/spark/sql/clickhouse/ConfigurationSuite.scala new file mode 100644 index 00000000..c5119c40 --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/test/scala/org/apache/spark/sql/clickhouse/ConfigurationSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.clickhouse + +import com.clickhouse.spark.Utils +import org.apache.spark.internal.config.ConfigEntry +import org.apache.spark.sql.internal.SQLConf +import org.scalatest.funsuite.AnyFunSuite + +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Path, Paths, StandardOpenOption} +import java.util +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.{universe => ru} + +/** + * End-to-end test cases for configuration documentation. + * + * The golden result file is "docs/configurations/02_sql_configurations.md". + * + * To run the entire test suite: + * {{{ + * ./gradlew test --tests=ConfigurationSuite + * }}} + * + * To re-generate golden files for entire suite, run: + * {{{ + * UPDATE=1 ./gradlew test --tests=ConfigurationSuite + * }}} + */ +class ConfigurationSuite extends AnyFunSuite { + + private val configurationsMarkdown = Paths + .get(Utils.getCodeSourceLocation(getClass).split("clickhouse-spark").head) + .resolve("..") + .resolve("docs") + .resolve("configurations") + .resolve("02_sql_configurations.md") + .normalize + + test("docs") { + ClickHouseSQLConf + + val newOutput = new ArrayBuffer[String] + newOutput += "---" + newOutput += "license: |" + newOutput += " Licensed under the Apache License, Version 2.0 (the \"License\");" + newOutput += " you may not use this file except in compliance with the License." + newOutput += " You may obtain a copy of the License at" + newOutput += " " + newOutput += " https://www.apache.org/licenses/LICENSE-2.0" + newOutput += " " + newOutput += " Unless required by applicable law or agreed to in writing, software" + newOutput += " distributed under the License is distributed on an \"AS IS\" BASIS," + newOutput += " WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied." 
+ newOutput += " See the License for the specific language governing permissions and" + newOutput += " limitations under the License." + newOutput += "---" + newOutput += "" + newOutput += "" + newOutput += "|Key | Default | Description | Since" + newOutput += "|--- | ------- | ----------- | -----" + + val sqlConfEntries: Seq[ConfigEntry[_]] = + ru.runtimeMirror(SQLConf.getClass.getClassLoader) + .reflect(SQLConf) + .reflectField(ru.typeOf[SQLConf.type].decl(ru.TermName("sqlConfEntries")).asTerm) + .get.asInstanceOf[util.Map[String, ConfigEntry[_]]] + .asScala.values.toSeq + + sqlConfEntries + .filter(entry => entry.key.startsWith("spark.clickhouse.") && entry.isPublic) + .sortBy(_.key) + .foreach { entry => + val seq = Seq( + s"${entry.key}", + s"${entry.defaultValueString}", + s"${entry.doc}", + s"${entry.version}" + ) + newOutput += seq.mkString("|") + } + newOutput += "" + + verifyOutput(configurationsMarkdown, newOutput, getClass.getCanonicalName) + } + + def verifyOutput(goldenFile: Path, newOutput: ArrayBuffer[String], agent: String): Unit = + if (System.getenv("UPDATE") == "1") { + val writer = Files.newBufferedWriter( + goldenFile, + StandardCharsets.UTF_8, + StandardOpenOption.TRUNCATE_EXISTING, + StandardOpenOption.CREATE + ) + try newOutput.foreach { line => + writer.write(line) + writer.newLine() + } + finally writer.close() + } else { + val expected = Files.readAllLines(goldenFile).asScala + val hint = s"$goldenFile is out of date, please update the golden file with " + + s"UPDATE=1 ./gradlew test --tests=ConfigurationSuite" + assert(newOutput.size === expected.size, hint) + + newOutput.zip(expected).foreach { case (out, in) => assert(out === in, hint) } + } +} diff --git a/spark-4.0/clickhouse-spark/src/test/scala/org/apache/spark/sql/clickhouse/FunctionRegistrySuite.scala b/spark-4.0/clickhouse-spark/src/test/scala/org/apache/spark/sql/clickhouse/FunctionRegistrySuite.scala new file mode 100644 index 00000000..df82a95d --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/test/scala/org/apache/spark/sql/clickhouse/FunctionRegistrySuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse + +import com.clickhouse.spark.func.{ + ClickHouseXxHash64, + ClickhouseEquivFunction, + CompositeFunctionRegistry, + DynamicFunctionRegistry, + StaticFunctionRegistry +} +import org.scalatest.funsuite.AnyFunSuite +import com.clickhouse.spark.func._ + +class FunctionRegistrySuite extends AnyFunSuite { + + val staticFunctionRegistry: StaticFunctionRegistry.type = StaticFunctionRegistry + val dynamicFunctionRegistry = new DynamicFunctionRegistry + dynamicFunctionRegistry.register("ck_xx_hash64", ClickHouseXxHash64) + dynamicFunctionRegistry.register("clickhouse_xxHash64", ClickHouseXxHash64) + + test("check StaticFunctionRegistry mappings") { + assert(staticFunctionRegistry.sparkToClickHouseFunc.forall { case (k, v) => + staticFunctionRegistry.load(k).get.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.contains(v) + }) + assert(staticFunctionRegistry.clickHouseToSparkFunc.forall { case (k, v) => + staticFunctionRegistry.load(v).get.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.contains(k) + }) + } + + test("check DynamicFunctionRegistry mappings") { + assert(dynamicFunctionRegistry.sparkToClickHouseFunc.forall { case (k, v) => + dynamicFunctionRegistry.load(k).get.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.contains(v) + }) + assert(dynamicFunctionRegistry.clickHouseToSparkFunc.forall { case (k, v) => + dynamicFunctionRegistry.load(v).get.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.contains(k) + }) + } + + test("check CompositeFunctionRegistry mappings") { + val compositeFunctionRegistry = + new CompositeFunctionRegistry(Array(staticFunctionRegistry, dynamicFunctionRegistry)) + assert(compositeFunctionRegistry.sparkToClickHouseFunc.forall { case (k, v) => + compositeFunctionRegistry.load(k).get.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.contains(v) + }) + assert(compositeFunctionRegistry.clickHouseToSparkFunc.forall { case (k, v) => + compositeFunctionRegistry.load(v).get.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.contains(k) + }) + } +} diff --git a/spark-4.0/clickhouse-spark/src/test/scala/org/apache/spark/sql/clickhouse/SchemaUtilsSuite.scala b/spark-4.0/clickhouse-spark/src/test/scala/org/apache/spark/sql/clickhouse/SchemaUtilsSuite.scala new file mode 100644 index 00000000..a16928cc --- /dev/null +++ b/spark-4.0/clickhouse-spark/src/test/scala/org/apache/spark/sql/clickhouse/SchemaUtilsSuite.scala @@ -0,0 +1,209 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.clickhouse + +import com.clickhouse.data.ClickHouseColumn +import org.apache.spark.sql.clickhouse.SchemaUtils._ +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite + +class SchemaUtilsSuite extends AnyFunSuite { + + case class TestBean(chTypeStr: String, sparkType: DataType, nullable: Boolean) + + private def assertPositive(positives: TestBean*): Unit = + positives.foreach { case TestBean(chTypeStr, expectedSparkType, expectedNullable) => + test(s"ch2spark - $chTypeStr") { + val chCols = ClickHouseColumn.parse(s"`col` $chTypeStr") + assert(chCols.size == 1) + val (actualSparkType, actualNullable) = fromClickHouseType(chCols.get(0)) + assert(actualSparkType === expectedSparkType) + assert(actualNullable === expectedNullable) + } + } + + private def assertNegative(negatives: String*): Unit = negatives.foreach { chTypeStr => + test(s"ch2spark - $chTypeStr") { + intercept[Exception] { + ClickHouseColumn.parse(s"`col` $chTypeStr") + val chCols = ClickHouseColumn.parse(s"`col` $chTypeStr") + assert(chCols.size == 1) + fromClickHouseType(chCols.get(0)) + } + } + } + + assertPositive( + TestBean( + "Array(String)", + ArrayType(StringType, containsNull = false), + nullable = false + ), + TestBean( + "Array(Nullable(String))", + ArrayType(StringType, containsNull = true), + nullable = false + ), + TestBean( + "Array(Array(String))", + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false), + nullable = false + ) + ) + + assertNegative( + "array(String)", + "Array(String" + ) + + assertPositive( + TestBean( + "Map(String, String)", + MapType(StringType, StringType, valueContainsNull = false), + nullable = false + ), + TestBean( + "Map(String,Int32)", + MapType(StringType, IntegerType, valueContainsNull = false), + nullable = false + ), + TestBean( + "Map(String,Nullable(UInt32))", + MapType(StringType, LongType, valueContainsNull = true), + nullable = false + ) + ) + + assertNegative( + "Map(String,)" + ) + + assertPositive( + TestBean( + "Date", + DateType, + nullable = false + ), + TestBean( + "DateTime", + TimestampType, + nullable = false + ), + TestBean( + "DateTime(Asia/Shanghai)", + TimestampType, + nullable = false + ), + TestBean( + "DateTime64", + TimestampType, + nullable = false + ) + // TestBean( + // "DateTime64(Europe/Moscow)", + // TimestampType, + // nullable = false + // ), + ) + + assertNegative( + "DT" + ) + + assertPositive( + TestBean( + "Decimal(2,1)", + DecimalType(2, 1), + nullable = false + ), + TestBean( + "Decimal32(5)", + DecimalType(9, 5), + nullable = false + ), + TestBean( + "Decimal64(5)", + DecimalType(18, 5), + nullable = false + ), + TestBean( + "Decimal128(5)", + DecimalType(38, 5), + nullable = false + ) + ) + + assertNegative( + "Decimal", // overflow + "Decimal256(5)", // overflow + "Decimal(String" + // "Decimal32(5" + ) + + assertPositive( + TestBean( + "String", + StringType, + nullable = false + ), + TestBean( + "FixedString(5)", + BinaryType, + nullable = false + ), + TestBean( + "LowCardinality(String)", + StringType, + nullable = false + ), + TestBean( + "LowCardinality(FixedString(5))", + BinaryType, + nullable = false + ), + TestBean( + "LowCardinality(Int32)", // illegal actually + IntegerType, + nullable = false + ) + ) + + assertNegative("fixedString(5)") + + test("spark2ch") { + val catalystSchema = StructType.fromString( + """{ + | "type": "struct", + | "fields": [ + | {"name": "id", "type": "integer", "nullable": false, "metadata": {}}, + | {"name": "food", 
"type": "string", "nullable": false, "metadata": {"comment": "food"}}, + | {"name": "price", "type": "decimal(2,1)", "nullable": false, "metadata": {"comment": "price usd"}}, + | {"name": "remark", "type": "string", "nullable": true, "metadata": {}}, + | {"name": "ingredient", "type": {"type": "array", "elementType": "string", "containsNull": true}, "nullable": true, "metadata": {}}, + | {"name": "nutrient", "type": {"type": "map", "keyType": "string", "valueType": "string", "valueContainsNull": true}, "nullable": true, "metadata": {}} + | ] + |} + |""".stripMargin + ) + assert(Seq( + ("id", "Int32", ""), + ("food", "String", " COMMENT 'food'"), + ("price", "Decimal(2, 1)", " COMMENT 'price usd'"), + ("remark", "Nullable(String)", ""), + ("ingredient", "Array(Nullable(String))", ""), + ("nutrient", "Map(String, Nullable(String))", "") + ) == toClickHouseSchema(catalystSchema)) + } +} diff --git a/spark-4.0/examples/README.md b/spark-4.0/examples/README.md new file mode 100644 index 00000000..a3a949ac --- /dev/null +++ b/spark-4.0/examples/README.md @@ -0,0 +1,179 @@ +# Spark 4.0 ClickHouse Connector Examples + +This directory contains example applications for debugging and testing the ClickHouse connector with Spark 4.0. + +## Prerequisites + +1. **ClickHouse Server Running** + ```bash + # Using Docker + docker run -d --name clickhouse-server \ + -p 8123:8123 -p 9000:9000 \ + --ulimit nofile=262144:262144 \ + clickhouse/clickhouse-server + ``` + +2. **Build the Connector** + ```bash + cd /Users/shimonsteinitz/Projects/spark-clickhouse-connector + ./gradlew -Dspark_binary_version=4.0 -Dscala_binary_version=2.13 :clickhouse-spark-4.0_2.13:build + ``` + +## Running Examples in IDE (for Debugging) + +### IntelliJ IDEA / VS Code with Metals + +1. **Import the project** as a Gradle project + +2. **Set up Run Configuration**: + - Main class: `examples.StreamingRateExample` or `examples.SimpleBatchExample` + - VM options: + ``` + -Dspark_binary_version=4.0 + -Dscala_binary_version=2.13 + ``` + - Working directory: `spark-4.0/examples` + - Classpath: Include `clickhouse-spark-4.0_2.13` module + +3. **Set Breakpoints** in connector code: + - Write path: `com.clickhouse.spark.write.ClickHouseWriter` + - Read path: `com.clickhouse.spark.read.ClickHouseReader` + - Catalog operations: `com.clickhouse.spark.ClickHouseCatalog` + +4. **Run in Debug Mode** and step through the connector code + +## Examples + +### 1. SimpleBatchExample + +A straightforward batch processing example that: +- Creates sample employee data +- Writes to ClickHouse +- Reads back and performs aggregations + +**Good for debugging**: +- Table creation logic +- Batch write operations +- Read operations +- Schema inference + +**Run**: +```bash +spark-submit \ + --class examples.SimpleBatchExample \ + --master local[*] \ + --jars clickhouse-spark-runtime-4.0_2.13.jar \ + examples/SimpleBatchExample.scala +``` + +### 2. 
StreamingRateExample + +A streaming application that: +- Uses Spark's rate source (generates synthetic data) +- Enriches data with multiple columns +- Writes to ClickHouse in micro-batches every 5 seconds +- Generates 10 rows per second + +**Good for debugging**: +- Streaming write operations +- Micro-batch processing +- Continuous data ingestion +- Performance under load + +**Run** (via the Gradle task defined in `spark-4.0/examples/build.gradle`): +```bash +./gradlew runStreaming +``` + +**Monitor the stream**: +```sql +-- In ClickHouse client +SELECT count(*) FROM default.streaming_events; + +SELECT + event_type, + count(*) as cnt, + avg(metric_value) as avg_metric +FROM default.streaming_events +GROUP BY event_type; + +SELECT + toStartOfMinute(event_time) as minute, + count(*) as events_per_minute +FROM default.streaming_events +GROUP BY minute +ORDER BY minute DESC +LIMIT 10; +``` + +## Debugging Tips + +### Enable Debug Logging + +Add to your SparkSession configuration: +```scala +.config("spark.sql.catalog.clickhouse.option.log.level", "DEBUG") +``` + +Or set log level programmatically: +```scala +spark.sparkContext.setLogLevel("DEBUG") +``` + +### Useful ClickHouse Queries + +```sql +-- Check table structure +DESCRIBE TABLE default.streaming_events; + +-- Check table engine and settings +SHOW CREATE TABLE default.streaming_events; + +-- Monitor inserts +SELECT + table, + sum(rows) as total_rows, + sum(bytes) as total_bytes +FROM system.parts +WHERE database = 'default' AND table IN ('streaming_events', 'employees') +GROUP BY table; + +-- Check recent parts +SELECT + partition, + name, + rows, + bytes_on_disk, + modification_time +FROM system.parts +WHERE database = 'default' AND table = 'streaming_events' +ORDER BY modification_time DESC +LIMIT 10; +``` + +## Troubleshooting + +### Connection Issues + +If you see connection errors: +```bash +# Verify ClickHouse is accessible +curl http://localhost:8123/ping +``` + +### Clean Up + +```sql +-- Drop tables +DROP TABLE IF EXISTS default.streaming_events; +DROP TABLE IF EXISTS default.employees; +``` + +```bash +# Remove checkpoint directory +rm -rf /tmp/clickhouse-streaming-checkpoint +``` diff --git a/spark-4.0/examples/build.gradle b/spark-4.0/examples/build.gradle new file mode 100644 index 00000000..27ccc5ac --- /dev/null +++ b/spark-4.0/examples/build.gradle @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +plugins { + id "scala" + id "idea" +} + +dependencies { + implementation "org.scala-lang:scala-library:${scala_version}" + + // Spark APIs are needed at runtime when running from IDE/Gradle + implementation "org.apache.spark:spark-sql_${scala_binary_version}:${spark_40_version}" + implementation "org.apache.spark:spark-streaming_${scala_binary_version}:${spark_40_version}" + + // Use connector project classes, and include shaded runtime on classpath + implementation project(":clickhouse-spark-4.0_${scala_binary_version}") + runtimeOnly project(":clickhouse-spark-runtime-4.0_${scala_binary_version}") +} + +sourceCompatibility = JavaVersion.VERSION_17 + +tasks.withType(ScalaCompile).configureEach { + scalaCompileOptions.additionalParameters = [ + "-deprecation", + "-feature" + ] +} + +// Application convenience tasks +tasks.register("runStreaming", JavaExec) { + group = "application" + description = "Run StreamingRateExample" + classpath = sourceSets.main.runtimeClasspath + mainClass = "examples.StreamingRateExample" + jvmArgs "-Dspark_binary_version=4.0", "-Dscala_binary_version=2.13" +} + +tasks.register("runBatch", JavaExec) { + group = "application" + description = "Run SimpleBatchExample" + classpath = sourceSets.main.runtimeClasspath + mainClass = "examples.SimpleBatchExample" + jvmArgs "-Dspark_binary_version=4.0", "-Dscala_binary_version=2.13" +} diff --git a/spark-4.0/examples/build.sbt b/spark-4.0/examples/build.sbt new file mode 100644 index 00000000..79845f7e --- /dev/null +++ b/spark-4.0/examples/build.sbt @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +name := "clickhouse-spark-examples" + +version := "1.0" + +scalaVersion := "2.13.8" + +val sparkVersion = "4.0.1" + +libraryDependencies ++= Seq( + "org.apache.spark" %% "spark-sql" % sparkVersion % "provided", + "org.apache.spark" %% "spark-streaming" % sparkVersion % "provided" +) + +// For local development, include the connector from the parent project +unmanagedClasspath in Compile ++= { + val base = baseDirectory.value / ".." / "clickhouse-spark" / "build" / "classes" + Seq( + Attributed.blank(base / "scala" / "main"), + Attributed.blank(base / "java" / "main") + ) +} diff --git a/spark-4.0/examples/src/main/scala/examples/SimpleBatchExample.scala b/spark-4.0/examples/src/main/scala/examples/SimpleBatchExample.scala new file mode 100644 index 00000000..82672c8f --- /dev/null +++ b/spark-4.0/examples/src/main/scala/examples/SimpleBatchExample.scala @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package examples + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions._ + +/** + * Simple batch example for debugging the ClickHouse connector. + * This example creates sample data and writes it to ClickHouse. + * + * Usage: + * 1. Start ClickHouse + * 2. Set breakpoints in connector code + * 3. Run this application in debug mode + */ +object SimpleBatchExample { + + def main(args: Array[String]): Unit = { + // Read connection parameters from environment or use defaults + val host = sys.env.getOrElse("CH_HOST", "localhost") + val protocol = sys.env.getOrElse("CH_PROTOCOL", "http") + val port = sys.env.getOrElse("CH_PORT", "8123") + val user = sys.env.getOrElse("CH_USER", "default") + val password = sys.env.getOrElse("CH_PASSWORD", "") + val database = sys.env.getOrElse("CH_DATABASE", "default") + + val spark = SparkSession.builder() + .appName("ClickHouse Simple Batch Example") + .master("local[*]") + .config("spark.sql.catalog.clickhouse", "com.clickhouse.spark.ClickHouseCatalog") + .config("spark.sql.catalog.clickhouse.host", host) + .config("spark.sql.catalog.clickhouse.protocol", protocol) + .config("spark.sql.catalog.clickhouse.http_port", port) + .config("spark.sql.catalog.clickhouse.user", user) + .config("spark.sql.catalog.clickhouse.password", password) + .config("spark.sql.catalog.clickhouse.database", database) + .config("spark.sql.catalog.clickhouse.option.ssl", (protocol == "https").toString) + .config( + "spark.executor.extraJavaOptions", + "--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED" + ) + .getOrCreate() + + spark.sparkContext.setLogLevel("WARN") + + println("=" * 80) + println("ClickHouse Simple Batch Example") + println(s"Connecting to: $protocol://$host:$port") + println("=" * 80) + + // Create sample data + import spark.implicits._ + val data = Seq( + (1L, "Alice", 25, "Engineering", 75000.0, java.sql.Timestamp.valueOf("2024-01-15 10:30:00")), + (2L, "Bob", 30, "Sales", 65000.0, java.sql.Timestamp.valueOf("2024-01-15 11:00:00")), + (3L, "Charlie", 35, "Engineering", 85000.0, java.sql.Timestamp.valueOf("2024-01-15 11:30:00")), + (4L, "Diana", 28, "Marketing", 70000.0, java.sql.Timestamp.valueOf("2024-01-15 12:00:00")), + (5L, "Eve", 32, "Engineering", 90000.0, java.sql.Timestamp.valueOf("2024-01-15 12:30:00")), + (6L, "Frank", 29, "Sales", 68000.0, java.sql.Timestamp.valueOf("2024-01-15 13:00:00")), + (7L, "Grace", 31, "Marketing", 72000.0, java.sql.Timestamp.valueOf("2024-01-15 13:30:00")), + (8L, "Henry", 27, "Engineering", 78000.0, java.sql.Timestamp.valueOf("2024-01-15 14:00:00")), + (9L, "Ivy", 33, "Sales", 71000.0, java.sql.Timestamp.valueOf("2024-01-15 14:30:00")), + (10L, "Jack", 26, "Engineering", 76000.0, java.sql.Timestamp.valueOf("2024-01-15 15:00:00")) + ).toDF("employee_id", "name", "age", "department", "salary", "hire_date") + + println("\nSample data to write:") + data.show(10, truncate = false) + + // Create table + println("\nCreating table...") + spark.sql(""" + CREATE TABLE IF NOT EXISTS clickhouse.default.employees ( + employee_id BIGINT NOT NULL, + name STRING, + age INT, + department STRING, + salary DOUBLE, + hire_date TIMESTAMP + ) USING clickhouse + TBLPROPERTIES ( + engine = 'MergeTree()', + order_by = 'employee_id' + ) + """) + println("✓ Table created") + + // Write data using catalog-aware API - Set breakpoints in connector write code to debug + println("\nWriting data to 
ClickHouse...") + data.writeTo("clickhouse.default.employees") + .append() + println("✓ Data written successfully") + + // Read data back using catalog table - Set breakpoints in connector read code to debug + println("\nReading data from ClickHouse...") + val result = spark.table("clickhouse.default.employees") + .orderBy("employee_id") + + println("\nData read from ClickHouse:") + result.show(10, truncate = false) + + // Perform aggregation + println("\nAggregation by department:") + result.groupBy("department") + .agg( + count("*").as("employee_count"), + avg("salary").as("avg_salary"), + avg("age").as("avg_age") + ) + .orderBy(desc("avg_salary")) + .show(truncate = false) + + println("\n" + "=" * 80) + println("Example completed successfully!") + println("=" * 80) + + spark.stop() + } +} diff --git a/spark-4.0/examples/src/main/scala/examples/StreamingRateExample.scala b/spark-4.0/examples/src/main/scala/examples/StreamingRateExample.scala new file mode 100644 index 00000000..6009745a --- /dev/null +++ b/spark-4.0/examples/src/main/scala/examples/StreamingRateExample.scala @@ -0,0 +1,176 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package examples + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.types._ + +/** + * Streaming example using rate source to write to ClickHouse. + * This example generates synthetic data and writes it to ClickHouse in micro-batches. + * + * Usage: + * 1. Start ClickHouse (e.g., via Docker) + * 2. Run this application + * 3. 
Set breakpoints in the connector code to debug + * + * The rate source generates rows with: + * - timestamp: event time + * - value: monotonically increasing long + */ +object StreamingRateExample { + + def main(args: Array[String]): Unit = { + val spark = SparkSession.builder() + .appName("ClickHouse Streaming Rate Example") + .master("local[*]") + .config("spark.sql.catalog.clickhouse", "com.clickhouse.spark.ClickHouseCatalog") + .config("spark.sql.catalog.clickhouse.host", "localhost") + .config("spark.sql.catalog.clickhouse.protocol", "http") + .config("spark.sql.catalog.clickhouse.http_port", "8123") + .config("spark.sql.catalog.clickhouse.user", "default") + .config("spark.sql.catalog.clickhouse.password", "") + .config("spark.sql.catalog.clickhouse.database", "default") + .config("spark.sql.catalog.clickhouse.option.ssl", "false") + .getOrCreate() + + spark.sparkContext.setLogLevel("WARN") + + println("=" * 80) + println("Starting ClickHouse Streaming Rate Example") + println("=" * 80) + + // Create the target table in ClickHouse + createClickHouseTable(spark) + + // Create a streaming DataFrame using rate source + // Generates 10 rows per second + val rateStream = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("rampUpTime", "0s") + .option("numPartitions", "2") + .load() + + // Transform the data to create a richer schema + val enrichedStream = rateStream + .withColumn("event_id", col("value")) + .withColumn("event_time", col("timestamp")) + .withColumn( + "event_type", + when(col("value") % 3 === 0, "type_a") + .when(col("value") % 3 === 1, "type_b") + .otherwise("type_c") + ) + .withColumn("metric_value", (rand() * 100).cast("double")) + .withColumn( + "status", + when(col("value") % 5 === 0, "completed") + .when(col("value") % 5 === 1, "pending") + .when(col("value") % 5 === 2, "failed") + .when(col("value") % 5 === 3, "processing") + .otherwise("queued") + ) + .withColumn("user_id", (col("value") % 1000).cast("int")) + .withColumn("session_id", concat(lit("session_"), (col("value") / 100).cast("int"))) + .withColumn( + "metadata", + to_json(struct( + lit("source").as("source_system"), + lit("v1.0").as("version"), + col("value").as("sequence_number") + )) + ) + .select( + "event_id", + "event_time", + "event_type", + "metric_value", + "status", + "user_id", + "session_id", + "metadata" + ) + + // Write to ClickHouse using foreachBatch for better control and debugging + val query = enrichedStream.writeStream + .foreachBatch { (batchDF: org.apache.spark.sql.DataFrame, batchId: Long) => + println(s"\n${"=" * 80}") + println(s"Processing Batch: $batchId") + println(s"Batch Size: ${batchDF.count()} rows") + println(s"${"=" * 80}") + + // Show sample data + println("\nSample data from this batch:") + batchDF.show(5, truncate = false) + + // Write to ClickHouse via catalog-aware V2 writer + // This avoids relying on a DataSource short name (clickhouse.DefaultSource) + batchDF.writeTo("clickhouse.default.streaming_events") + .option("write_format", "arrow") + .option("compression_codec", "none") + .append() + + println(s"✓ Batch $batchId written successfully to ClickHouse\n") + } + .trigger(Trigger.ProcessingTime("5 seconds")) + .option("checkpointLocation", "/tmp/clickhouse-streaming-checkpoint") + .start() + + println("\n" + "=" * 80) + println("Streaming query started. 
Press Ctrl+C to stop.") + println("=" * 80) + println("\nYou can query the data in ClickHouse:") + println(" SELECT * FROM default.streaming_events ORDER BY event_time DESC LIMIT 10;") + println(" SELECT event_type, count(*) FROM default.streaming_events GROUP BY event_type;") + println(" SELECT status, avg(metric_value) FROM default.streaming_events GROUP BY status;") + println("=" * 80 + "\n") + + // Wait for the query to terminate + query.awaitTermination() + } + + private def createClickHouseTable(spark: SparkSession): Unit = { + println("\nCreating ClickHouse table if not exists...") + + try { + spark.sql(""" + CREATE TABLE IF NOT EXISTS clickhouse.default.streaming_events ( + event_id BIGINT, + event_time TIMESTAMP NOT NULL, + event_type STRING, + metric_value DOUBLE, + status STRING, + user_id INT, + session_id STRING, + metadata STRING + ) + TBLPROPERTIES ( + engine = 'MergeTree()', + order_by = 'event_time', + partition_by = 'toYYYYMMDD(event_time)', + settings.index_granularity = 8192 + ) + """) + println("✓ Table created successfully or already exists\n") + } catch { + case e: Exception => + println(s"Warning: Could not create table: ${e.getMessage}") + println("Make sure ClickHouse is running and accessible\n") + } + } +}