diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala
index 91421a37c64..7c06ef74182 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2024, NVIDIA CORPORATION.
+ * Copyright (c) 2024-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -73,7 +73,7 @@ import org.apache.spark.storage.BlockId
  * An object is spillable (it will be copied to host or disk during OOM) if:
  * - it has a approxSizeInBytes > 0
  * - it is not actively being referenced by the user (call to `materialize`, or aliased)
- * - it hasn't already spilled
+ * - it hasn't already spilled, and is not currently being spilled
  * - it hasn't been closed
  *
  * Aliasing:
@@ -140,11 +140,6 @@ import org.apache.spark.storage.BlockId
 * `materialize` on a handle, the handle must guarantee that it can satisfy that, even if the caller
 * should wait until a spill happens. This is currently implemented using the handle lock.
 *
- * Note that we hold the handle lock while we are spilling (performing IO). That means that no other
- * consumer can access this spillable device handle while it is being spilled, including a second
- * thread that is trying to spill and is generating a spill plan, as the handle lock is likely held
- * up with IO. We will relax this likely in follow on work.
- *
 * We never hold a store-wide coarse grain lock in the stores when we do IO.
 */
 
@@ -160,9 +155,21 @@ trait StoreHandle extends AutoCloseable {
   * removed on shutdown, or by handle.close, but 0-byte handles are not spillable.
   */
  val approxSizeInBytes: Long
+
+  /**
+   * This is used to resolve races between closing a handle and spilling.
+   */
+  private[spill] var closed: Boolean = false
 }
 
 trait SpillableHandle extends StoreHandle {
+  /**
+   * Used to gate an active spill so that a second thread won't also begin spilling,
+   * and so that a handle won't release the underlying buffer if it is closed while
+   * a spill is in flight.
+   */
+  private[spill] var spilling: Boolean = false
+
  /**
   * Method called to spill this handle. It can be triggered from the spill store,
   * or directly against the handle.
@@ -170,6 +177,14 @@ trait SpillableHandle extends StoreHandle {
   * This will not free the spilled data. If you would like to free the spill
   * call `releaseSpilled`
   *
+   * This is a thread-safe method. If multiple threads call it at the same time, one
+   * thread will win and perform the spill, and the others will make no
+   * modification.
+   *
+   * If the disk is full, or a spill otherwise fails (e.g. device issues),
+   * we make no attempt to handle the error or restore state, as we expect to be in a
+   * non-recoverable state at the task/executor level.
+   *
   * @note The size returned from this method is only used by the spill framework
   *   to track the approximate size. It should just return `approxSizeInBytes`, as
   *   that's the size that it used when it first started tracking the object.
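To make the "one thread wins" contract above concrete, here is a minimal sketch, not part of the patch: a simplified handle where the flag is flipped under the handle lock, the IO runs outside it, and losing threads return 0. `doSpillIO` is a hypothetical stand-in for the per-handle IO.

    trait GatedSpill {
      // guarded by `this`; mirrors SpillableHandle.spilling in the patch
      private var spilling: Boolean = false

      protected def doSpillIO(): Long  // performs the actual copy; no locks held

      def spill(): Long = {
        val thisThreadSpills = synchronized {
          if (!spilling) { spilling = true; true } else false
        }
        if (thisThreadSpills) {
          try doSpillIO()
          finally synchronized { spilling = false }
        } else {
          0L  // another thread is already spilling this handle
        }
      }
    }

The key point is that only the flag flip is serialized; the IO itself no longer holds the handle lock, so other consumers are not blocked behind a slow spill.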
@@ -178,7 +193,14 @@ trait SpillableHandle extends StoreHandle {
  def spill(): Long
 
  /**
-  * Method used to determine whether a handle tracks an object that could be spilled
+  * Method used to determine whether a handle tracks an object that could be spilled.
+  * This is only a first-pass filter: there is a case where a handle will appear
+  * spillable according to this check, but a thread will then not be able to spill
+  * it, because another thread has already started spilling the handle.
+  * However, this is not expected to cause an issue, as it can only come up when
+  * multiple threads try to spill with overlapping spill plans. It will not, for
+  * instance, produce any false negatives.
+  *
  * @note At the level of `SpillableHandle`, the only requirement of spillability
  *   is that the size of the handle is > 0. `approxSizeInBytes` is known at
  *   construction, and is immutable.
@@ -322,12 +344,25 @@ class SpillableHostBufferHandle private (
     if (!spillable) {
       0L
     } else {
-      val spilled = synchronized {
-        if (disk.isEmpty && host.isDefined) {
+      val thisThreadSpills = synchronized {
+        if (disk.isEmpty && host.isDefined && !spilling) {
+          spilling = true
+          // incRefCount here so that if close() is called
+          // while we are spilling, we will prevent the buffer being freed
+          host.get.incRefCount()
+          true
+        } else {
+          false
+        }
+      }
+      val spilled = if (thisThreadSpills) {
+        withResource(host.get) { buf =>
           withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder =>
             val outputChannel = diskHandleBuilder.getChannel
+            // the spill IO is non-blocking, as it won't touch dev or host directly;
+            // instead we "atomically" swap the buffers in below once they are ready
             GpuTaskMetrics.get.spillToDiskTime {
-              val iter = new HostByteBufferIterator(host.get)
+              val iter = new HostByteBufferIterator(buf)
              iter.foreach { bb =>
                try {
                  while (bb.hasRemaining) {
@@ -338,19 +373,29 @@ class SpillableHostBufferHandle private (
              }
            }
          }
-          disk = Some(diskHandleBuilder.build)
-          sizeInBytes
+            var staging: Option[DiskHandle] = Some(diskHandleBuilder.build)
+            synchronized {
+              spilling = false
+              if (closed) {
+                staging.foreach(_.close())
+                staging = None
+                doClose()
+              } else {
+                disk = staging
+              }
+              releaseHostResource()
+            }
-      } else {
-        0L
         }
+        sizeInBytes
+      } else {
+        0
       }
-      releaseHostResource()
      spilled
    }
  }
 
-  override def close(): Unit = {
+  private def doClose(): Unit = {
    releaseHostResource()
    synchronized {
      disk.foreach(_.close())
      disk = None
@@ -358,6 +403,15 @@ class SpillableHostBufferHandle private (
    }
  }
 
+  override def close(): Unit = {
+    synchronized {
+      closed = true
+    }
+    if (!spilling) {
+      doClose()
+    }
+  }
+
  private[spill] def materializeToDeviceMemoryBuffer(dmb: DeviceMemoryBuffer): Unit = {
    var hostBuffer: HostMemoryBuffer = null
    var diskHandle: DiskHandle = null
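The hunks above implement a "stage, then swap" protocol: the spilled copy is built without holding the handle lock, and only the final pointer swap happens under `synchronized`, where the `closed` flag tells the spiller whether `close()` ran mid-spill. A condensed sketch of just that state machine, with hypothetical names and strings standing in for real buffers:

    class StageThenSwap {
      private var closed = false
      private var spilling = false
      private var payload: Option[String] = Some("on-device data")
      private var spilledCopy: Option[String] = None

      private def doClose(): Unit = { payload = None; spilledCopy = None }

      def spill(): Unit = {
        val thisThreadSpills = synchronized {
          if (spilledCopy.isEmpty && payload.isDefined && !spilling) {
            spilling = true; true
          } else false
        }
        if (thisThreadSpills) {
          var staging: Option[String] = Some("copied: " + payload.get) // IO, no lock held
          synchronized {
            spilling = false
            if (closed) { staging = None; doClose() } // close() ran mid-spill: discard staging
            else { spilledCopy = staging }            // publish the spilled copy atomically
          }
        }
      }

      def close(): Unit = {
        synchronized { closed = true }
        if (!spilling) doClose()  // otherwise the in-flight spiller performs the close
      }
    }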
@@ -456,24 +510,57 @@ class SpillableDeviceBufferHandle private (
    if (!spillable) {
      0L
    } else {
-      synchronized {
-        if (host.isEmpty && dev.isDefined) {
-          host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(dev.get))
-          sizeInBytes
+      val thisThreadSpills = synchronized {
+        if (host.isEmpty && dev.isDefined && !spilling) {
+          spilling = true
+          // incRefCount here so that if close() is called
+          // while we are spilling, we will prevent the buffer being freed
+          dev.get.incRefCount()
+          true
        } else {
-          0L
+          false
+        }
+      }
+      if (thisThreadSpills) {
+        withResource(dev.get) { buf =>
+          // the spill IO is non-blocking, as it won't touch dev or host directly;
+          // instead we "atomically" swap the buffers in below once they are ready
+          var stagingHost: Option[SpillableHostBufferHandle] =
+            Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(buf))
+          synchronized {
+            spilling = false
+            if (closed) {
+              stagingHost.foreach(_.close())
+              stagingHost = None
+              doClose()
+            } else {
+              host = stagingHost
+            }
+          }
        }
+        sizeInBytes
+      } else {
+        0
      }
    }
  }
 
-  override def close(): Unit = {
+  private def doClose(): Unit = {
    releaseDeviceResource()
    synchronized {
      host.foreach(_.close())
      host = None
    }
  }
+
+  override def close(): Unit = {
+    synchronized {
+      closed = true
+    }
+    if (!spilling) {
+      doClose()
+    }
+  }
 }
 
 class SpillableColumnarBatchHandle private (
@@ -531,29 +618,44 @@ class SpillableColumnarBatchHandle private (
    if (!spillable) {
      0L
    } else {
-      synchronized {
-        if (host.isEmpty && dev.isDefined) {
-          withChunkedPacker { chunkedPacker =>
-            meta = Some(chunkedPacker.getPackedMeta)
-            host = Some(SpillableHostBufferHandle.createHostHandleWithPacker(chunkedPacker))
-          }
-          // We return the size we were created with. This is not the actual size
-          // of this batch when it is packed, and it is used by the calling code
-          // to figure out more or less how much did we free in the device.
-          approxSizeInBytes
+      val thisThreadSpills = synchronized {
+        if (host.isEmpty && dev.isDefined && !spilling) {
+          spilling = true
+          GpuColumnVector.incRefCounts(dev.get)
+          true
        } else {
-          0L
+          false
        }
      }
+      if (thisThreadSpills) {
+        withChunkedPacker(dev.get) { chunkedPacker =>
+          meta = Some(chunkedPacker.getPackedMeta)
+          var staging: Option[SpillableHostBufferHandle] =
+            Some(SpillableHostBufferHandle.createHostHandleWithPacker(chunkedPacker))
+          synchronized {
+            spilling = false
+            if (closed) {
+              staging.foreach(_.close())
+              staging = None
+              doClose()
+            } else {
+              host = staging
+            }
+          }
+        }
+        // We return the size we were created with. This is not the actual size
+        // of this batch when it is packed, and it is used by the calling code
+        // to figure out more or less how much did we free in the device.
+        approxSizeInBytes
+      } else {
+        0L
+      }
    }
  }
 
-  private def withChunkedPacker[T](body: ChunkedPacker => T): T = {
-    val tbl = synchronized {
-      if (dev.isEmpty) {
-        throw new IllegalStateException("cannot get copier without a batch")
-      }
-      GpuColumnVector.from(dev.get)
+  private def withChunkedPacker[T](batchToPack: ColumnarBatch)(body: ChunkedPacker => T): T = {
+    val tbl = withResource(batchToPack) { _ =>
+      GpuColumnVector.from(batchToPack)
    }
    withResource(tbl) { _ =>
      withResource(new ChunkedPacker(tbl, SpillFramework.chunkedPackBounceBufferPool)) { packer =>
@@ -562,13 +664,22 @@ class SpillableColumnarBatchHandle private (
      }
    }
  }
 
-  override def close(): Unit = {
+  private def doClose(): Unit = {
    releaseDeviceResource()
    synchronized {
      host.foreach(_.close())
      host = None
    }
  }
+
+  override def close(): Unit = {
+    synchronized {
+      closed = true
+    }
+    if (!spilling) {
+      doClose()
+    }
+  }
 }
 
 object SpillableColumnarBatchFromBufferHandle {
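A note on the `incRefCount`/`incRefCounts` calls in these hunks: the spilling thread takes its own reference to the device data while it still holds the handle lock, so even if `close()` drops the handle's reference mid-spill, the buffer stays alive until `withResource` releases the spiller's reference. A tiny illustration with a hypothetical ref-counted resource, script-style:

    // Hypothetical ref-counted resource showing why the spiller takes its own
    // reference before doing IO: the buffer must outlive a concurrent close().
    class RefCounted(var refs: Int = 1) {
      def incRefCount(): this.type = synchronized { refs += 1; this }
      def close(): Unit = synchronized {
        refs -= 1
        if (refs == 0) println("freed")  // only the last reference frees the memory
      }
    }

    val buf = new RefCounted()  // the handle's reference
    buf.incRefCount()           // the spiller's reference, taken under the handle lock
    buf.close()                 // handle closed mid-spill: not freed, refs == 1
    // ... spill IO can still read buf safely here ...
    buf.close()                 // spiller done: refs == 0, now it is freed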
@@ -654,27 +765,56 @@ class SpillableColumnarBatchFromBufferHandle private (
    if (!spillable) {
      0
    } else {
-      synchronized {
-        if (host.isEmpty && dev.isDefined) {
-          val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuColumnVectorFromBuffer]
+      val thisThreadSpills = synchronized {
+        if (host.isEmpty && dev.isDefined && !spilling) {
+          spilling = true
+          GpuColumnVector.incRefCounts(dev.get)
+          true
+        } else {
+          false
+        }
+      }
+      if (thisThreadSpills) {
+        withResource(dev.get) { cb =>
+          val cvFromBuffer = cb.column(0).asInstanceOf[GpuColumnVectorFromBuffer]
          meta = Some(cvFromBuffer.getTableMeta)
-          host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(
-            cvFromBuffer.getBuffer))
+          var staging: Option[SpillableHostBufferHandle] =
+            Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(
+              cvFromBuffer.getBuffer))
+          synchronized {
+            spilling = false
+            if (closed) {
+              doClose()
+              staging.foreach(_.close())
+              staging = None
+            } else {
+              host = staging
+            }
+          }
          sizeInBytes
-        } else {
-          0L
        }
+      } else {
+        0L
      }
    }
  }
 
-  override def close(): Unit = {
+  private def doClose(): Unit = {
    releaseDeviceResource()
    synchronized {
      host.foreach(_.close())
      host = None
    }
  }
+
+  override def close(): Unit = {
+    synchronized {
+      closed = true
+    }
+    if (!spilling) {
+      doClose()
+    }
+  }
 }
 
 object SpillableCompressedColumnarBatchHandle {
@@ -738,21 +878,41 @@ class SpillableCompressedColumnarBatchHandle private (
    if (!spillable) {
      0L
    } else {
-      synchronized {
-        if (host.isEmpty && dev.isDefined) {
-          val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuCompressedColumnVector]
+      val thisThreadSpills = synchronized {
+        if (host.isEmpty && dev.isDefined && !spilling) {
+          spilling = true
+          GpuCompressedColumnVector.incRefCounts(dev.get)
+          true
+        } else {
+          false
+        }
+      }
+      if (thisThreadSpills) {
+        withResource(dev.get) { cb =>
+          val cvFromBuffer = cb.column(0).asInstanceOf[GpuCompressedColumnVector]
          meta = Some(cvFromBuffer.getTableMeta)
-          host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(
-            cvFromBuffer.getTableBuffer))
+          var staging: Option[SpillableHostBufferHandle] =
+            Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(
+              cvFromBuffer.getTableBuffer))
+          synchronized {
+            spilling = false
+            if (closed) {
+              doClose()
+              staging.foreach(_.close())
+              staging = None
+            } else {
+              host = staging
+            }
+          }
          compressedSizeInBytes
-        } else {
-          0L
        }
+      } else {
+        0L
      }
    }
  }
 
-  override def close(): Unit = {
+  private def doClose(): Unit = {
    releaseDeviceResource()
    synchronized {
      host.foreach(_.close())
@@ -760,6 +919,15 @@ class SpillableCompressedColumnarBatchHandle private (
      host = None
      meta = None
    }
  }
+
+  override def close(): Unit = {
+    synchronized {
+      closed = true
+    }
+    if (!spilling) {
+      doClose()
+    }
+  }
 }
 
 object SpillableHostColumnarBatchHandle {
@@ -822,33 +990,61 @@ class SpillableHostColumnarBatchHandle private (
    if (!spillable) {
      0L
    } else {
-      val spilled = synchronized {
-        if (disk.isEmpty && host.isDefined) {
+      val thisThreadSpills = synchronized {
+        if (disk.isEmpty && host.isDefined && !spilling) {
+          spilling = true
+          RapidsHostColumnVector.incRefCounts(host.get)
+          true
+        } else {
+          false
+        }
+      }
+      val bytesSpilled = if (thisThreadSpills) {
+        withResource(host.get) { cb =>
          withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder =>
            GpuTaskMetrics.get.spillToDiskTime {
              val dos = diskHandleBuilder.getDataOutputStream
-              val columns = RapidsHostColumnVector.extractBases(host.get)
-              JCudfSerialization.writeToStream(columns, dos, 0, host.get.numRows())
+              val columns = RapidsHostColumnVector.extractBases(cb)
+              JCudfSerialization.writeToStream(columns, dos, 0, cb.numRows())
+            }
+            var staging: Option[DiskHandle] = Some(diskHandleBuilder.build)
+            synchronized {
+              spilling = false
+              if (closed) {
+                doClose()
+                staging.foreach(_.close())
+                staging = None
+              } else {
+                disk = staging
+              }
            }
-            disk = Some(diskHandleBuilder.build)
+            releaseHostResource()
            approxSizeInBytes
          }
-        } else {
-          0L
        }
+      } else {
+        0L
      }
-      releaseHostResource()
-      spilled
+      bytesSpilled
    }
  }
 
-  override def close(): Unit = {
+  private def doClose(): Unit = {
    releaseHostResource()
    synchronized {
      disk.foreach(_.close())
      disk = None
    }
  }
+
+  override def close(): Unit = {
+    synchronized {
+      closed = true
+    }
+    if (!spilling) {
+      doClose()
+    }
+  }
 }
 
 object DiskHandle {
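The suite below probes the close-during-spill window with randomized sleeps. An alternative worth noting, sketched here with hypothetical hooks (the real handles expose no such injection points), is to make the interleaving deterministic with latches so that `close()` provably lands mid-spill on every run:

    import java.util.concurrent.CountDownLatch

    // Deterministic version of the race: the "spiller" parks inside its IO section
    // until the closer has run, guaranteeing close() happens while a spill is in flight.
    val ioStarted = new CountDownLatch(1)
    val closeDone = new CountDownLatch(1)

    val spiller = new Thread(() => {
      ioStarted.countDown()  // signal: spill IO has begun
      closeDone.await()      // hold the spill open until close() has happened
      // ... finish the spill; the handle must now discard staging and self-close ...
    })
    spiller.start()
    ioStarted.await()
    // handle.close() would go here, while the spill is provably in flight
    closeDone.countDown()
    spiller.join()

The sleep-based sweep trades that determinism for coverage of many interleavings, including close() landing before and after the spill window.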
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala
index 31377695fe4..2e0a3afa26d 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2024, NVIDIA CORPORATION.
+ * Copyright (c) 2024-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -1102,4 +1102,86 @@ class SpillFrameworkSuite
    testBufferFileDeletion(canShareDiskPaths = true)
  }
 
+  def testCloseWhileSpilling[T <: SpillableHandle](handle: T, store: SpillableStore[T],
+      sleepBeforeCloseNanos: Long): Unit = {
+    assert(handle.spillable)
+    assertResult(1)(store.numHandles)
+    val t1 = new Thread(() => {
+      // we cannot assert how much this spills, because it depends on whether the handle
+      // is already closed or not, and we are trying to force both conditions in this
+      // test to show that potential races are handled correctly
+      store.spill(handle.approxSizeInBytes)
+    })
+    t1.start()
+
+    // we observed that the race will typically trigger when sleeping between 0.1 and 1 ms
+    Thread.sleep(sleepBeforeCloseNanos / 1000000L, (sleepBeforeCloseNanos % 1000000L).toInt)
+    handle.close()
+    t1.join()
+    assertResult(0)(store.numHandles)
+  }
+
+  // This is a small Monte Carlo simulation where we overlap closing buffers and
+  // spilling at different delay points, to tease out possible race conditions.
+  // There's only one param/variable in the simulation, but it could be extended
+  // to N params if needed.
+  def monteCarlo(oneIteration: Long => Unit): Unit = {
+    for (i <- 1L to 10L) {
+      val nanos: Long = i * 100 * 1000
+      oneIteration(nanos)
+    }
+  }
+
+  test("a non-contiguous table close while spilling") {
+    monteCarlo { sleepBeforeCloseNanos =>
+      val (tbl, dataTypes) = buildTable()
+      val handle = SpillableColumnarBatchHandle(tbl, dataTypes)
+      testCloseWhileSpilling(handle, SpillFramework.stores.deviceStore, sleepBeforeCloseNanos)
+    }
+  }
+
+  test("a device buffer close while spilling") {
+    monteCarlo { sleepBeforeCloseNanos =>
+      val (ct, _) = buildContiguousTable()
+      // the contract for spillable handles is that they take ownership;
+      // incRefCount to follow that pattern
+      val buff = ct.getBuffer
+      buff.incRefCount()
+      val handle = SpillableDeviceBufferHandle(buff)
+      ct.close()
+      testCloseWhileSpilling(handle, SpillFramework.stores.deviceStore, sleepBeforeCloseNanos)
+    }
+  }
+
+  test("host columnar batch close while spilling") {
+    monteCarlo { sleepBeforeCloseNanos =>
+      val (hostCb, _) = buildHostBatch()
+      val handle = SpillableHostColumnarBatchHandle(hostCb)
+      testCloseWhileSpilling(handle, SpillFramework.stores.hostStore, sleepBeforeCloseNanos)
+    }
+  }
+
+  test("host memory buffer close while spilling") {
+    monteCarlo { sleepBeforeCloseNanos =>
+      val handle = SpillableHostBufferHandle(HostMemoryBuffer.allocate(1024))
+      testCloseWhileSpilling(handle, SpillFramework.stores.hostStore, sleepBeforeCloseNanos)
+    }
+  }
+
+  test("cb from buffer handle close while spilling") {
+    monteCarlo { sleepBeforeCloseNanos =>
+      val (ct, dataTypes) = buildContiguousTable()
+      val handle = SpillableColumnarBatchFromBufferHandle(ct, dataTypes)
+      testCloseWhileSpilling(handle, SpillFramework.stores.deviceStore, sleepBeforeCloseNanos)
+    }
+  }
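The harness's comment notes that the sweep could be extended to N parameters. A sketch of what a two-parameter version might look like (hypothetical; the second delay, for example one applied before starting the spill thread, is not something the current tests vary):

    def monteCarlo2(oneIteration: (Long, Long) => Unit): Unit = {
      for {
        i <- 1L to 10L  // delay before close(), 0.1 ms .. 1 ms
        j <- 1L to 10L  // second, hypothetical delay to vary
      } oneIteration(i * 100 * 1000, j * 100 * 1000)
    }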
+
+  test("compressed cb handle close while spilling") {
+    monteCarlo { sleepBeforeCloseNanos =>
+      val ct = buildCompressedBatch(0, 1000)
+      val handle = SpillableCompressedColumnarBatchHandle(ct)
+      testCloseWhileSpilling(handle, SpillFramework.stores.deviceStore, sleepBeforeCloseNanos)
+    }
+  }
+
 }
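Putting the pieces together, a single iteration of these tests reduces to the following shape (a script-style condensation of the harness above, using the same APIs the suite exercises: `SpillableHostBufferHandle`, `SpillFramework.stores.hostStore.spill`, and `approxSizeInBytes`):

    val handle = SpillableHostBufferHandle(HostMemoryBuffer.allocate(1024))
    val spiller = new Thread(() =>
      SpillFramework.stores.hostStore.spill(handle.approxSizeInBytes))
    spiller.start()
    Thread.sleep(0, 500000)  // ~0.5 ms: inside the window where the race typically triggers
    handle.close()           // may land before, during, or after the spill IO
    spiller.join()           // either way, the handle must end up fully released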