雑記

AtCoderをPythonでやりきりたい!【heapq】

for分でnumpy回したら計算時間オーバー!

https://atcoder.jp/contests/abc141/tasks/abc141_d

こちらの「D – Powerful Discount Tickets」という問題が解けませんでした。

要約すると、リストの一番でっかい値を1/2し続ける問題です。

私の提出したコードは以下のコードです。

import numpy as np
N, M = map(int, input().split())
items = np.array(list(map(int, input().split())))
for i in range(M):
    items[items.argmax()]/=2
print(round(np.sum(items)))

しかしこれでは処理時間がオーバーしてしまいました。(あと余は切り捨てなので、若干間違ってますが)

その対応のためにheapqを利用してみようと思います。

今回はこちらの記事を参考にさせていただきました。

最大値(最小値)に対する処理を高速化する

heapqを使ってみる

ポイントは3つです。

  1. heapq.heapify(list) で最小値を取得しやすいリストに変換する
  2. heapq.heappush(list,x) 変換したリストにxを追加する
  3. heapq.heappop(a) 最小値を削除し、その値を返す

なので今回の問題においては、すべての数値を負の値にして、最小値に対して、1/2し続ければOKです。

ベンチマークしてみる

サンプルデータ

import numpy as np
np.random.seed(0)
N, M = 100000, 100000
items_origin = np.random.randint(1, 1000000000, N).tolist()

今回の問題でとりうる最大数を用意します。

numpyとfor文で回してみた結果

%%timeit
import numpy as np
items = np.array(items_origin)
for i in range(M):
    items[items.argmax()]/=2
print(round(np.sum(items)))
4.97 s ± 105 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

heapqを使ってみた結果

%%timeit
import heapq
A = [-val for val in items_origin]
heapq.heapify(A)
tmp = 0
for m in range(M):
    div = -heapq.heappop(A)
    div = div // 2
    heapq.heappush(A, -div)
print(-sum(A))
158 ms ± 11.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

だいたい300倍くらい速くなってますね。

参考

AtCoderを勉強するときは蟻本がおすすめだそうです。

https://qiita.com/drken/items/e77685614f3c6bf86f44

https://qiita.com/drken/items/2f56925972c1d34e05d8

https://amzn.to/2Nh2X02

ABOUT ME
hirayuki
今年で社会人3年目になります。 日々体当たりで仕事を覚えています。 テーマはIT・教育です。 少しでも技術に親しんでもらえるよう、noteで4コマ漫画も書いています。 https://note.mu/hirayuki