#!/usr/bin/env python3

import os
import sys
import platform
import math
from datetime import datetime,timezone

# Provenance that ends up in the plot title: the sortbench version read from
# '../version' (relative to the current directory), the machine name, and
# today's date in UTC as YYYYMMDD.
with open('../version') as versionfile:
  sortbenchversion = versionfile.read().strip()

hostname = platform.node()
when = f'{datetime.now(timezone.utc):%Y%m%d}'

labels = [] # labels whose benchmark file exists, in command-line order
data = {} # (n,label,color) -> list of cycle counts (as floats) for size n
nset = set() # every size n that starts a data line in any benchmark file

overhead = [] # cycle counts from every 'overhead' line, pooled across files
cpucycles_versions = set() # second field of each 'cpucycles_version' line
cpucycles_implementations = set() # second field of each 'cpucycles_implementation' line

# Command line: bits, followed by any number of (label,color) pairs.
args = sys.argv[1:]
try:
  bits = int(args[0])
except (IndexError,ValueError):
  # report how to call the script rather than dying with a bare traceback
  sys.exit(f'usage: {sys.argv[0]} bits [label color]...')
args = args[1:]
if len(args) % 2:
  # a label without a color cannot be plotted; it is skipped, but say so
  print(f'warning: ignoring unpaired trailing argument {args[-1]!r}',file=sys.stderr)

for label,color in zip(args[0::2],args[1::2]):
  fn = f'{label}/bench{bits}-{label}.out'
  # a label with no benchmark output is skipped rather than treated as an error
  if not os.path.exists(fn): continue
  labels.append(label)
  with open(fn) as f:
    for line in f:
      line = line.strip().split()
      if len(line) < 1: continue
      if line[0] == 'cpucycles_version':
        # a keyword with no value is ignored instead of raising IndexError
        if len(line) >= 2: cpucycles_versions.add(line[1])
        continue
      if line[0] == 'cpucycles_implementation':
        if len(line) >= 2: cpucycles_implementations.add(line[1])
        continue
      if line[0] == 'overhead':
        overhead.extend(int(cycles) for cycles in line[1:])
        continue
      # remaining lines of interest look like: n cycles cycles cycles ...
      if not line[0].isdigit(): continue
      n = int(line[0])
      nset.add(n)
      # no samples on this line: do not create an empty entry in data
      if len(line) < 2: continue
      # int() first so that a non-integer cycle count is still rejected
      samples = data.setdefault((n,label,color),[])
      samples.extend(float(int(cycles)) for cycles in line[1:])

def mean(S):
  """Return the arithmetic mean of the values in the iterable S.

  S is materialized first, so one-shot iterators are accepted.
  Raises ZeroDivisionError if S is empty.
  """
  values = tuple(S)
  total = sum(values)
  return total/len(values)

def stq(S): # stabilized quartiles; see https://cr.yp.to/papers.html#rsrst
  """Return a 3-tuple of stabilized quartiles of the values in S.

  Every value is replicated 8 times so that the sorted list splits exactly
  into eighths. The three results are the means of eighths 2-3, 4-5 and 6-7
  of that sorted list. Raises ZeroDivisionError (from mean) if S is empty.
  """
  replicated = sorted(8*list(S))
  n = len(replicated)//8 # length of one eighth = number of original values
  return tuple(mean(replicated[lo*n:hi*n]) for lo,hi in ((1,3),(3,5),(5,7)))

# Reduce the pooled overhead samples to their stabilized quartiles. From here
# on, 'overhead' is the single number subtracted from every measurement.
if not overhead:
  # stq() would otherwise fail with an uninformative ZeroDivisionError
  raise Exception('no overhead measurements found, aborting')
overhead = stq(overhead)
print(f'overhead {overhead[0]:.2f} {overhead[1]:.2f} {overhead[2]:.2f}')
overhead = overhead[1] # second stabilized quartile

print(f'cpucycles_versions {" ".join(sorted(cpucycles_versions))}')
print(f'cpucycles_implementations {" ".join(sorted(cpucycles_implementations))}')

# Refuse to plot unless all input files report the same cpucycles version
# and the same cpucycles implementation.
if len(cpucycles_versions) != 1:
  raise Exception('not exactly one cpucycles_version, aborting')
if len(cpucycles_implementations) != 1:
  raise Exception('not exactly one cpucycles_implementation, aborting')

# Convert the samples of each (n,label,color) into stabilized quartiles
# expressed in cycles per byte, after subtracting the timing overhead.
todo = [] # (q2,q1,q3,n,label,color) for every bar to draw
ylist = [] # every plotted y value; used for the y-axis limits

for key in sorted(data):
  n = key[0]
  perbyte = 8.0/float(bits*n) # an array of n bits-bit integers is bits*n/8 bytes
  q1,q2,q3 = [(quartile-overhead)*perbyte for quartile in stq(data[key])]
  # report a few representative sizes on stdout
  if n in (16,1024,65536):
    print(f'{key[0]} {key[1]}: {q1:.2f}, {q2:.2f}, {q3:.2f}')
  todo.append((q2,q1,q3)+key)
  ylist.extend((q1,q2,q3))

import matplotlib
import matplotlib.pyplot as plt

# One 6x6 inch figure; the title records what is plotted, the machine name,
# the date of this run and the sortbench version.
fig,ax = plt.subplots()
fig.set_size_inches((6,6))
plt.title(f'y = cycles/byte to sort int{bits}[x] on {hostname} ({when}; sortbench-{sortbenchversion})',fontsize='medium')
# Base-2 logarithmic scale on both axes.
plt.semilogx(base=2)
plt.semilogy(base=2)
# Leave a factor 1.4 of room below the smallest and above the largest value in ylist.
plt.ylim((min(ylist)/1.4,max(ylist)*1.4))
# Major ticks at powers of 2 on both axes.
ax.xaxis.set_major_locator(matplotlib.ticker.LogLocator(base=2,numticks=22))
ax.yaxis.set_major_locator(matplotlib.ticker.LogLocator(base=2,numticks=20))
# Minor y ticks at 1.414213562373095 (about sqrt(2)) times each power of 2,
# i.e. halfway between major ticks on the log scale. The empty format string
# presumably leaves these ticks unlabeled, so they only contribute grid lines.
ax.yaxis.set_minor_locator(matplotlib.ticker.LogLocator(base=2,subs=(1.414213562373095,),numticks=20))
ax.yaxis.set_minor_formatter('')
ax.yaxis.grid(True,which='minor')
plt.grid(visible=True)
plt.tight_layout()

handles = {} # label -> the error bar that represents that label in the legend

# Cap width of the error bar drawn at each size n: 10 per doubling of the
# distance to the next larger size, i.e. 10*log2(next/n). The largest size
# has no successor and keeps the default of 0.5.
widthscale = 10.0/math.log(2)
sizes = sorted(nset)
logs = [math.log(size) for size in sizes]
nwidth = dict.fromkeys(sizes,0.5)
for size,thislog,nextlog in zip(sizes,logs,logs[1:]):
  nwidth[size] = widthscale*(nextlog-thislog)

# Draw one error bar per entry of todo: the marker sits at the second
# stabilized quartile and the whiskers reach the first and third.
# Entries are drawn in decreasing order of (q2,q1,q3,n,label,color).
firstbar = {} # label -> first bar drawn for that label, whatever its n
for q2,q1,q3,n,label,color in reversed(sorted(todo)):
  capsize = nwidth[n]
  onebar = plt.errorbar([n],[q2],yerr=([q2-q1],[q3-q2]),label=label,marker='x',markersize=1.0,linewidth=0.5,capsize=capsize,markeredgewidth=0.5,color=color)
  firstbar.setdefault(label,onebar)
  # NOTE(review): the legend prefers a bar with n > 2**10; the reason is not
  # visible in this file, so that preference is kept unchanged.
  if label not in handles and n > 2**10:
    handles[label] = onebar

# A label with no bar at n > 2**10 is represented by its first drawn bar.
for label,onebar in firstbar.items():
  handles.setdefault(label,onebar)

ax.text(1,min(ylist),'faster',va='bottom',rotation=-90,bbox=dict(boxstyle='rarrow',lw=2,fc='lightgray',ec='gray',alpha=0.8))
# Labels whose file held no data rows have no bar and are left out of the
# legend, so building it cannot raise KeyError.
plt.legend(handles=[handles[label] for label in labels if label in handles],loc='lower right',fontsize=8)

# Write the same figure in three formats.
plt.savefig(f'plot{bits}.pdf')
plt.savefig(f'plot{bits}.png')
plt.savefig(f'plot{bits}.svg')
