bachelor-project/src/gale_shapley.mpc

160 lines
5.5 KiB
Python

# vim: ft=python
from Compiler import types
from Compiler.util import *
from Compiler.oram import OptimalORAM
from Compiler.library import for_range, do_while, time, if_, print_ln, crash, print_str
from Compiler.gs import OMatrix, OMatrixRow, OStack
class Matchmaker:
    """
    Based on Matchmaker from Compiler/gs.py in MP-SPDZ, copyright (c) 2023,
    Commonwealth Scientific and Industrial Research Organisation (CSIRO)
    ABN 41 687 119 230, published under the BSD 3-Clause Licence

    Runs a proposal loop in which patients (the proposing side) propose to
    therapists.  All bookkeeping lives in oblivious containers built from
    ``oram_type``:

    * ``self.wives``     -- indexed by patient, stores the matched therapist.
    * ``self.husbands``  -- indexed by therapist, stores the matched patient.
    * ``self.unengaged`` -- stack of patients that still have to propose.

    These three attributes are created by :meth:`match`, so the other
    methods must not be called before it.

    NOTE(review): the runtime log strings keep the upstream "man"/"woman"
    wording; in this file "man" is the patient and "woman" the therapist.
    """

    def engage(self, patient, therapist, for_real):
        """Record the pair (patient, therapist) in both lookup tables.

        ``for_real`` is forwarded as the last argument of both ``access``
        calls -- presumably the oblivious "actually write" flag of the ORAM;
        confirm against Compiler/oram.py.
        """
        self.wives.access(patient, therapist, for_real)
        self.husbands.access(therapist, patient, for_real)

    def dump(self, patient, therapist, for_real):
        """Undo the pair (patient, therapist) and re-queue the patient.

        Deletes the patient's entry from ``wives`` and the therapist's entry
        from ``husbands``, then pushes the patient onto ``unengaged``.  Every
        step is gated by ``for_real``.
        """
        self.wives.delete(patient, for_real)
        self.husbands.delete(therapist, for_real)
        self.unengaged.append(patient, for_real)

    def propose(self, patient, therapist, for_real):
        """Let ``patient`` propose to ``therapist`` and update the matching.

        The therapist's current partner (if any) is compared with the
        proposer via the values stored in ``self.t_exps[therapist]``; the
        proposer wins when its stored value is strictly smaller.  Depending
        on the outcome the old partner is dumped, the proposer is engaged,
        or the proposer is pushed back onto ``unengaged``.

        WARNING: the two print statements reveal secret values (both
        identities, the current partner and the compared ranks).  They are
        debugging output inherited from upstream.
        """
        # `free` is the second value returned by the read; it is treated as
        # "therapist has no partner yet".
        (fiance,), free = self.husbands.read(therapist)
        engaged = 1 - free
        rank_man = self.t_exps[therapist][patient]
        # Multiplying by `engaged` makes the lookup index 0 when there is no
        # current partner.
        (rank_fiance,), worst_fiance = self.t_exps[therapist].read(engaged*fiance)
        leaving = self.int_type(rank_man) < self.int_type(rank_fiance)
        if self.M < self.N:
            # leaving OR worst_fiance, written arithmetically so it also
            # works on secret bits.
            leaving = 1 - (1 - leaving) * (1 - worst_fiance)
        print_str('woman: %s, man: %s, fiance: %s, worst fiance: %s, ',
                  *(x.reveal() for x in (therapist, patient, fiance, worst_fiance)))
        print_ln('rank man: %s, rank fiance: %s, engaged: %s, leaving: %s',
                 *(x.reveal() for x in
                   (rank_man, rank_fiance, engaged, leaving)))
        # Old partner is released only if there was one and it lost.
        self.dump(fiance, therapist, engaged * leaving * for_real)
        # Proposer is accepted unless the therapist is engaged and stays.
        self.engage(patient, therapist, (1 - (engaged * (1 - leaving))) * for_real)
        # Rejected proposer goes back onto the stack to propose again.
        self.unengaged.append(patient, engaged * (1 - leaving) * for_real)

    def match(self, n_loops=None):
        """Run the proposal loop and print the resulting matching.

        :param n_loops: if ``None`` or larger than ``N * M``, iterate with
            ``do_while`` until ``unengaged`` is empty (the emptiness bit is
            revealed every round).  Otherwise run exactly ``n_loops``
            iterations and start with only ``n_loops // M`` patients on the
            stack.

        Side effects: (re)creates ``self.wives``, ``self.husbands`` and
        ``self.unengaged``, and emits debugging output that reveals the
        proposer, its preference counter and the stack size each round, and
        finally the content of ``husbands``.
        """
        if n_loops is None or n_loops > self.N * self.M:
            loop = do_while
            init_rounds = self.N
        else:
            loop = for_range(n_loops)
            # Floor division: the upstream code relied on Python 2 integer
            # division; a plain `/` yields a float under Python 3, which is
            # not a valid loop bound for for_range below.
            init_rounds = n_loops // self.M
        self.wives = \
            self.oram_type(self.N, entry_size=log2(self.N),
                           init_rounds=0, value_type=self.basic_type)
        self.husbands = \
            self.oram_type(self.N, entry_size=log2(self.N),
                           init_rounds=0, value_type=self.basic_type)
        # Per-patient counter: index of the next preference to propose to.
        propose = \
            self.oram_type(self.N, entry_size=log2(self.N),
                           init_rounds=0, value_type=self.basic_type)
        self.unengaged = OStack(self.N, oram_type=self.oram_type,
                                int_type=self.int_type)

        # Initially every participating patient is unengaged.
        @for_range(init_rounds)
        def f(i):
            self.unengaged.append(i)

        rounds = types.MemValue(types.regint(0))

        @loop
        def f(i=None):
            rounds.iadd(1)
            time()
            patient = self.unengaged.pop()
            pref = self.int_type(propose[patient])
            if self.M < self.N and n_loops is None:
                # With truncated preference lists a patient can exhaust its
                # list; abort rather than read past the end.
                @if_((pref == self.M).reveal())
                def f():
                    print_ln('run out of acceptable women')
                    crash()
            propose[patient] = pref + 1
            self.propose(patient, self.p_cases[patient][pref], True)
            print_ln('man: %s, pref: %s, left: %s',
                     *(x.reveal() for x in (patient, pref, self.unengaged.size)))
            # Loop condition for do_while: continue while patients remain.
            return types.regint((self.unengaged.size > 0).reveal())

        print_ln('%s rounds', rounds)

        @for_range(init_rounds)
        def f(i):
            types.cint(i).print_reg('wife')
            self.husbands[i].reveal().print_reg('husb')

    def __init__(self, N, p_cases, t_exps, M=1, reverse=False,
                 oram_type=OptimalORAM, int_type=types.sint):
        """Store the problem size, preference data and container types.

        :param N: number of entries in each oblivious table.
        :param p_cases: indexed as ``p_cases[patient][pref]`` to obtain the
            therapist a patient proposes to next.
        :param t_exps: indexed as ``t_exps[therapist][patient]`` to obtain
            the value used to compare two patients.
        :param M: length of each preference list (compared against ``N`` and
            against the per-patient preference counter).
        :param reverse: stored but not used by any method of this class.
        :param oram_type: class used to build the oblivious tables.
        :param int_type: secret integer type; its ``basic_type`` becomes the
            ORAM value type.
        """
        self.N = N
        self.M = M
        self.p_cases = p_cases
        self.t_exps = t_exps
        self.reverse = reverse
        self.oram_type = oram_type
        self.int_type = int_type
        self.basic_type = int_type.basic_type
        print('match', N, M)
""" CONSTANTS """
# Number of parties that supply input.
PLAYERS = 3
# Number of values each party supplies per matrix.
MATCHING_SIZE = 2

# Assembling lists: one row per player, one column per supplied share.
_share_matrix_args = dict(rows=PLAYERS, columns=MATCHING_SIZE, value_type=sint)
p_shares = Matrix(**_share_matrix_args)
t_shares = Matrix(**_share_matrix_args)
def _read_player_inputs(shares):
    """Fill ``shares`` with secret inputs, one row per player.

    The matrix is ordered as m[row:player][col:share].
    """
    @for_range(PLAYERS)
    def _(player):
        @for_range(MATCHING_SIZE)
        def _(slot):
            shares[player][slot] = sint.get_input_from(player)


# The patient matrix is read first, then the therapist matrix, so every
# player supplies its patient values before its therapist values.
_read_player_inputs(p_shares)
_read_player_inputs(t_shares)
def _print_player_inputs(shares):
    """Reveal and print every entry of ``shares`` (debugging output only)."""
    @for_range(PLAYERS)
    def _(player):
        @for_range(MATCHING_SIZE)
        def _(slot):
            opened = shares[player][slot].reveal()
            print_ln('input from player %s: %s', player, opened)


# Patient inputs are printed first, then therapist inputs.
_print_player_inputs(p_shares)
_print_player_inputs(t_shares)
# Sum each column over all players to combine the additively shared inputs.
# The sums stay secret; they are only opened by the debug prints below.
_omatrix_args = dict(N=MATCHING_SIZE, M=1, oram_type=OptimalORAM,
                     int_type=types.sint)
p_cases = OMatrix(**_omatrix_args)
t_exps = OMatrix(**_omatrix_args)


@for_range(MATCHING_SIZE)
def _(col):
    patient_total = sum(p_shares.get_column(col))
    p_cases[col][0] = patient_total
    therapist_total = sum(t_shares.get_column(col))
    t_exps[col][0] = therapist_total
    print_ln('p_res: %s', patient_total.reveal())
    print_ln('t_res: %s', therapist_total.reveal())


# Run the matching on the combined inputs.
mm = Matchmaker(N=MATCHING_SIZE, p_cases=p_cases, t_exps=t_exps)
mm.match()