We discussed stochastic nodes this week (see e.g. this paper for background), and this is the API we settled on.
EDIT: WE'VE CHANGED THE API
y1 = f(input1, weights)
z1 = sample(y1)
r1 = get_reward(z1)
y2 = f(input2, weights)
z2 = sample(y2)
r2 = get_reward(z2)
# r1 and r2 are vectors of scalar rewards
z1.reinforce(r1 + gamma*r2)
z2.reinforce(r2)
(r1 + r2).backward()
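In the snippet above, each `reinforce` call receives the discounted sum of rewards from that step onward: `z1` gets `r1 + gamma*r2` and `z2` gets `r2`. For more than two steps, that bookkeeping could look roughly like the sketch below. It is plain Python and assumes one scalar reward per step; `discounted_returns` is a hypothetical helper for illustration, not part of the proposed API.

```python
def discounted_returns(rewards, gamma):
    """Return G_t = r_t + gamma * G_{t+1} for every step t.

    Hypothetical helper: `rewards` is a list of per-step scalar rewards,
    `gamma` is the discount factor.
    """
    returns = []
    running = 0.0
    # Walk backwards so each step can reuse the return of the step after it.
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


# Two steps, mirroring r1 and r2 above (values chosen for illustration).
returns = discounted_returns([1.0, 4.0], gamma=0.5)
# returns[0] corresponds to r1 + gamma*r2, returns[1] to r2
```

With these values, `returns[0]` would be the argument for `z1.reinforce(...)` and `returns[1]` the argument for `z2.reinforce(...)`.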
cc: @ludc @yuandong-tian