I need to replace the last two values in the value column of a pl.DataFrame with zeros, grouped by the symbol column.
import polars as pl

df = pl.DataFrame(
    {"symbol": [*["A"] * 4, *["B"] * 4], "value": range(8)}
)
shape: (8, 2)
┌────────┬───────┐
│ symbol ┆ value │
│ ---    ┆ ---   │
│ str    ┆ i64   │
╞════════╪═══════╡
│ A      ┆ 0     │
│ A      ┆ 1     │
│ A      ┆ 2     │
│ A      ┆ 3     │
│ B      ┆ 4     │
│ B      ┆ 5     │
│ B      ┆ 6     │
│ B      ┆ 7     │
└────────┴───────┘
Here is my expected outcome:
shape: (8, 2)
┌────────┬───────┐
│ symbol ┆ value │
│ ---    ┆ ---   │
│ str    ┆ i64   │
╞════════╪═══════╡
│ A      ┆ 0     │
│ A      ┆ 1     │
│ A      ┆ 0     │ <-- replaced
│ A      ┆ 0     │ <-- replaced
│ B      ┆ 4     │
│ B      ┆ 5     │
│ B      ┆ 0     │ <-- replaced
│ B      ┆ 0     │ <-- replaced
└────────┴───────┘
Asked Nov 20, 2024 at 15:24 by Andi
2 Answers
You can use pl.Expr.head() with pl.len() to get each group's data without its last two rows, then pl.Expr.append() and pl.repeat() to pad it with zeros.
df.with_columns(
    pl.col.value.head(pl.len() - 2).append(pl.repeat(0, 2))
    .over("symbol")
)
shape: (8, 2)
┌────────┬───────┐
│ symbol ┆ value │
│ ---    ┆ ---   │
│ str    ┆ i64   │
╞════════╪═══════╡
│ A      ┆ 0     │
│ A      ┆ 1     │
│ A      ┆ 0     │
│ A      ┆ 0     │
│ B      ┆ 4     │
│ B      ┆ 5     │
│ B      ┆ 0     │
│ B      ┆ 0     │
└────────┴───────┘
Alternatively, you can use pl.when() to create a conditional column, with pl.int_range() and pl.len() to keep only the first n - 2 rows of each group, where n is the group length.
df.with_columns(
    pl.when(pl.int_range(pl.len()) < pl.len() - 2).then(pl.col.value)
    .otherwise(0)
    .over("symbol")
)
shape: (8, 2)
┌────────┬───────┐
│ symbol ┆ value │
│ ---    ┆ ---   │
│ str    ┆ i64   │
╞════════╪═══════╡
│ A      ┆ 0     │
│ A      ┆ 1     │
│ A      ┆ 0     │
│ A      ┆ 0     │
│ B      ┆ 4     │
│ B      ┆ 5     │
│ B      ┆ 0     │
│ B      ┆ 0     │
└────────┴───────┘
You can use .is_last_distinct() and .shift():
df.with_columns(
    pl.when(
        pl.any_horizontal(
            pl.col("symbol").is_last_distinct(),
            pl.col("symbol").shift(-1).over("symbol").is_last_distinct()
        ).not_()
    )
    .then(pl.col("value"))
    .otherwise(0)
)
shape: (8, 2)
┌────────┬───────┐
│ symbol ┆ value │
│ ---    ┆ ---   │
│ str    ┆ i64   │
╞════════╪═══════╡
│ A      ┆ 0     │
│ A      ┆ 1     │
│ A      ┆ 0     │
│ A      ┆ 0     │
│ B      ┆ 4     │
│ B      ┆ 5     │
│ B      ┆ 0     │
│ B      ┆ 0     │
└────────┴───────┘