
apache spark sql - Aggregate across multiple columns of different types when pivoting in pyspark - Stack Overflow


I have a melted table of the form:


+------+---------+--------------+------------+--------------+
| time | channel | value_double | value_long | value_string |
+------+---------+--------------+------------+--------------+
|    0 | A       | 1.1          | null       | null         |
|    0 | B       | null         | 1          | null         |
|    0 | C       | null         | null       | "foo"        |
|    1 | A       | 2.1          | null       | null         |
|    1 | B       | null         | 2          | null         |
|    1 | C       | null         | null       | "bar"        |
|    2 | A       | 3.1          | null       | null         |
|    2 | B       | null         | 3          | null         |
|    2 | C       | null         | null       | "foobar"     |
+------+---------+--------------+------------+--------------+

And I'd like to pivot this table to be:

+------+-----+---+----------+
| time | A   | B | C        |
+------+-----+---+----------+
| 0    | 1.1 | 1 | "foo"    |
| 1    | 2.1 | 2 | "bar"    |
| 2    | 3.1 | 3 | "foobar" |
+------+-----+---+----------+

I've got something along the lines of:

df.groupBy("time").pivot("channel").agg(...)

But I'm struggling to fill in the agg function to aggregate across the different value columns. I've tried coalesce, but it runs into errors because of the distinct types of the columns.

Some key points:

  • The three value columns have distinct types (double, long, and string)
  • The typing is consistent per channel
  • There is always one and only one value column with data per row

Is this possible with PySpark/SparkSQL?
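
For reference, a minimal sketch of how the sample table could be built as a DataFrame (column types assumed from the table above):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Recreate the melted sample table from the question
df = spark.createDataFrame(
    [
        (0, "A", 1.1, None, None),
        (0, "B", None, 1, None),
        (0, "C", None, None, "foo"),
        (1, "A", 2.1, None, None),
        (1, "B", None, 2, None),
        (1, "C", None, None, "bar"),
        (2, "A", 3.1, None, None),
        (2, "B", None, 3, None),
        (2, "C", None, None, "foobar"),
    ],
    schema="time long, channel string, value_double double, value_long long, value_string string",
)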


asked Nov 16, 2024 at 4:33 by GolferDude
  • Are channels A, B, and C always guaranteed to have non-null values that are all the same type? – Derek O, Nov 17, 2024 at 0:46
  • Yes, they are. They will always have one and only one non-null value of the same type. – GolferDude, Nov 17, 2024 at 1:07

4 Answers

Use coalesce and first functions together.

from pyspark.sql import functions as F
...
df = df.groupBy('time').pivot('channel').agg(F.first(F.coalesce('value_double', 'value_long', 'value_string')))
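
If the mixed column types still trip up coalesce (as the question describes), one workaround is to cast each value column to string before coalescing; the pivoted A, B and C columns then all come out as strings (a sketch, not part of the answer above):

from pyspark.sql import functions as F

# Cast every value column to string so coalesce sees a single type
df_pivoted = (
    df.groupBy("time")
      .pivot("channel")
      .agg(F.first(F.coalesce(
          F.col("value_double").cast("string"),
          F.col("value_long").cast("string"),
          F.col("value_string"),
      )))
)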

Breaking down the steps for better understanding:

from pyspark.sql.functions import col, when, first


# Identify the appropriate value column for each channel:
df = df.withColumn("value",
                   when(col("channel") == "A", col("value_double"))
                   .when(col("channel") == "B", col("value_long"))
                   .otherwise(col("value_string")))

# Then, pivot the DataFrame:
df_pivoted = df.groupBy("time").pivot("channel").agg(first("value"))

df_pivoted.show()

EDIT: leveraging dynamic type inference and conversion when the number of channels is high and an explicit per-channel mapping is not practical. Note that the date/timestamp parsing below uses the Python standard library, because Spark's to_date/to_timestamp operate on Columns and cannot be applied to plain values inside a UDF.

from pyspark.sql.functions import first, udf
from datetime import date, datetime


# Define a function to dynamically infer the appropriate data type:
def infer_data_type(value):
    try:
        return float(value)
    except ValueError:
        try:
            return int(value)
        except ValueError:
            try:
                return date.fromisoformat(value)
            except ValueError:
                try:
                    return datetime.fromisoformat(value)
                except ValueError:
                    return str(value)

# Create a UDF for the data type inference.
# udf() defaults to a StringType return type, so the inferred values come back
# to Spark as strings unless an explicit returnType is supplied.
infer_type_udf = udf(infer_data_type)

# Pivot the DataFrame, inferring the data type for each value:
df_pivoted = df.groupBy("time").pivot("channel").agg(first(infer_type_udf("value")))

If it is guaranteed that there is exactly one non-null value for each channel at each time, you could restructure the DataFrame before pivoting:

from pyspark.sql import functions as F

df_filtered = []

# Pivot each value column separately, keeping only its non-null rows
for value_col in ["value_double", "value_long", "value_string"]:
    df_filtered.append(
        df.select("time", "channel", value_col)
          .dropna()
          .groupby("time")
          .pivot("channel")
          .agg(F.max(value_col))
    )

for i, df_curr in enumerate(df_filtered):
    if i == 0:
        df_all = df_curr
    else:
        df_all = df_all.join(df_curr, on=['time'], how='inner')

Result:

+----+---+---+------+
|time|  A|  B|     C|
+----+---+---+------+
|   1|2.1|  2|   bar|
|   2|3.1|  3|foobar|
|   0|1.1|  1|   foo|
+----+---+---+------+
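
The join loop above can also be written as a single fold; an equivalent sketch (not part of the original answer):

from functools import reduce

# Inner-join all per-value-column pivots on "time" in one pass
df_all = reduce(lambda left, right: left.join(right, on="time", how="inner"), df_filtered)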

Use colRegex to dynamically identify the value_* columns, then build a map column (m) from them:

from pyspark.sql import functions as F

cols = df.select(df.colRegex(r'`value_.*`')).columns
m = F.create_map(*[y for c in cols for y in (F.lit(c), F.col(c))])
df1 = df.select('time', 'channel', m.alias('m'))

# df1.show()
# +----+-------+--------------------------------------------------------------------+
# |time|channel|m                                                                   |
# +----+-------+--------------------------------------------------------------------+
# |0   |A      |{value_double -> 1.1, value_long -> null, value_string -> null}     |
# |0   |B      |{value_double -> null, value_long -> 1.0, value_string -> null}     |
# |0   |C      |{value_double -> null, value_long -> null, value_string -> "foo"}   |
# |1   |A      |{value_double -> 2.1, value_long -> null, value_string -> null}     |
# |1   |B      |{value_double -> null, value_long -> 2.0, value_string -> null}     |
# |1   |C      |{value_double -> null, value_long -> null, value_string -> "bar"}   |
# |2   |A      |{value_double -> 3.1, value_long -> null, value_string -> null}     |
# |2   |B      |{value_double -> null, value_long -> 3.0, value_string -> null}     |
# |2   |C      |{value_double -> null, value_long -> null, value_string -> "foobar"}|

Use map_filter to remove the null key-value pairs

df1 = df1.withColumn('m', F.map_filter('m', lambda k, v: ~F.isnull(v)))

# df1.show()
# +----+-------+--------------------------+
# |time|channel|m                         |
# +----+-------+--------------------------+
# |0   |A      |{value_double -> 1.1}     |
# |0   |B      |{value_long -> 1.0}       |
# |0   |C      |{value_string -> "foo"}   |
# |1   |A      |{value_double -> 2.1}     |
# |1   |B      |{value_long -> 2.0}       |
# |1   |C      |{value_string -> "bar"}   |
# |2   |A      |{value_double -> 3.1}     |
# |2   |B      |{value_long -> 3.0}       |
# |2   |C      |{value_string -> "foobar"}|
# +----+-------+--------------------------+

Group the DataFrame by time and pivot on channel

df1 = df1.groupBy('time').pivot('channel').agg(F.first('m'))

# df1.show()
# +----+---------------------+-------------------+--------------------------+
# |time|A                    |B                  |C                         |
# +----+---------------------+-------------------+--------------------------+
# |0   |{value_double -> 1.1}|{value_long -> 1.0}|{value_string -> "foo"}   |
# |1   |{value_double -> 2.1}|{value_long -> 2.0}|{value_string -> "bar"}   |
# |2   |{value_double -> 3.1}|{value_long -> 3.0}|{value_string -> "foobar"}|
# +----+---------------------+-------------------+--------------------------+

Use map_values to extract the value from the mapping

df1 = df1.select('time', *[F.map_values(c)[0].alias(c) for c in df1.columns[1:]])

# df1.show()
# +----+---+---+--------+
# |time|A  |B  |C       |
# +----+---+---+--------+
# |0   |1.1|1.0|"foo"   |
# |1   |2.1|2.0|"bar"   |
# |2   |3.1|3.0|"foobar"|
# +----+---+---+--------+
