Separate forward and backward compilation and support higher order derivatives for aot_function #856

Open · wants to merge 12 commits into base: gh/anjali411/1/base
Changes from 4 commits
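
The goal of the PR, in user-facing terms: once the forward and backward graphs are compiled separately, the backward pass itself can be re-traced and differentiated, so torch.autograd.grad can be called twice through an aot_function-compiled function. The sketch below is only an illustration of that intent. It assumes the functorch.compile.aot_function entry point exercised by the tests in this diff, uses an inline identity "compiler" (what the tests call `nop`), and the second grad call is exactly the case this PR is trying to enable, so it is not expected to work before the change.

```python
import torch
from functorch.compile import aot_function  # assumed public entry point, as in the tests below

def identity_compiler(fx_module, example_inputs):
    # Stand-in "compiler": return the traced FX module unchanged.
    return fx_module

def f(x):
    return torch.sin(x).sum()

aot_f = aot_function(f, identity_compiler, identity_compiler)

x = torch.randn(4, requires_grad=True)
out = aot_f(x)
(grad,) = torch.autograd.grad(out, x, create_graph=True)   # first derivative: cos(x)
(grad_grad,) = torch.autograd.grad(grad.sum(), x)           # second derivative: -sin(x)
print(torch.allclose(grad_grad, -torch.sin(x)))
```

The helpers added to test/test_pythonkey.py below check the same thing: a first torch.autograd.grad call with create_graph=True, followed by a second grad call through the resulting gradients.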
91 changes: 72 additions & 19 deletions functorch/_src/aot_autograd.py
@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:

def create_joint_forward_backward(fn):
    def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
    ) -> Tuple[List[Any], List[Any]]:
        # Call the forward pass
        outs = fn(*primals)
@@ -68,21 +68,21 @@ def joint_forward_backward(
                grad_primals.append(p)

        # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
        needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
            if isinstance(out, Tensor) and out.requires_grad:
                needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
        backward_out = []
        # Call the backwards pass
        if grad_primals:
            backward_out = torch.autograd.grad(
                needed_outs,
                grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
+                grad_outputs=needed_cotangents,
+                allow_unused=True
            )
        backward_out_iter = iter(backward_out)
        return outs, [
@@ -140,12 +140,13 @@ def create_aot_autograd_function(
    compiled_fw = None
    compiled_bw = None
    num_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}

    class CompiledFunction(torch.autograd.Function):
        @staticmethod
        @disable_torchdynamo
        def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs
            if compiled_fw is None:
                with torch.set_grad_enabled(grad_state):
                    out = flat_fn(*flat_tensor_args)
@@ -159,31 +160,83 @@ def forward(ctx, *flat_tensor_args):
                    num_outs = 1

                joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
                # Need it because autograd.Function disables grad in forward
                with torch.set_grad_enabled(grad_state):
                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                        *joint_inputs
                    )
                fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                # print(fw_module.code, bw_module.code)

                compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))

-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    print("ENTERING default_partition")
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    print("fw outs: ", fw_outs, "-------")
+                    ctx.save_for_backward(*to_be_saved)
+                    ctx.fwd_graph = fw_module.code
+                else:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
+                    ctx.save_for_backward(*fw_outs[num_outs:])
            else:
                fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-            ctx.save_for_backward(*fw_outs[num_outs:])
+                if partition_fn is default_partition:
+                    with torch.set_grad_enabled(grad_state):
+                        out = flat_fn(*flat_tensor_args)
+                    out = pytree.tree_map(
+                        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
+                    )
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    ctx.save_for_backward(*to_be_saved)
+                else:
+                    ctx.save_for_backward(*fw_outs[num_outs:])
            return tuple(fw_outs[0:num_outs])

        @staticmethod
        @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
-            return tuple(out)
+        def backward(ctx, *flat_grad_outs):
+            print(flat_grad_outs)
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                assert partition_fn is default_partition
+                with torch.set_grad_enabled(grad_state):
+                    inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate+ctx.num_inputs]
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(inputs, contiguous_args)
+                # assert that the forward graph generated here is the same
+                # if it's specified that the user might want to calculate double backwards
+                fw_module, bw_module = partition_fn(fx_g, ctx.saved_tensors[ctx.num_intermediate:])
+                print(fw_module.code)
+                print(ctx.fwd_graph)
+                assert fw_module.code == ctx.fwd_graph
+                func_code = bw_module.code.split('self, ')
+                # print(func_code[0] + func_code[1])
+                exec(func_code[0] + func_code[1], globals())
+                f = create_aot_autograd_function(forward, bw_compiler, bw_compiler, partition_fn, aot_decompositions, grad_state)
Review comment (Contributor), on the create_aot_autograd_function call above:

Two questions:

  • Why are we passing forward to create_aot_autograd_function? I would have expected us to pass bw_module.code without the self argument.
  • What is the exec for? Are you trying to test this without the create_aot_autograd_function line?

Reply (Author):

  • forward is the name of the function generated by running bw_module.code.
  • exec runs bw_module.code to create a backward function, which becomes the forward for the next pass.
+                # print(bw_module.code, *ctx.saved_tensors, contiguous_args)
+                # print(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+                # print(*ctx.saved_tensors[ctx.num_intermediate:], *contiguous_args)
+                return f.apply(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+            else:
+                assert not torch.is_grad_enabled()
+                out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+                return tuple(out)
+            # nonlocal compiled_bw
+            # contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            # if compiled_bw is None:
+            #     with torch.set_grad_enabled(grad_state):
+            #         fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args)
+            #     # assert that the forward graph generated here is the same
+            #     # if it's specified that the user might want to calculate double backwards
+            #     fw_module, bw_module = partition_fn(fx_g, joint_inputs)
+            #     compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args)
+            # out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+            # return tuple(out)

    return CompiledFunction

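A note on the structure of the default_partition path above: having backward() build another autograd.Function and return its apply() is the standard PyTorch recipe for making a custom Function twice-differentiable. The sketch below shows that recipe in isolation, using plain PyTorch and none of the PR's tracing or compilation machinery; Square and SquareBackward are made-up names.

```python
import torch

class SquareBackward(torch.autograd.Function):
    # Computes the backward of Square as its own differentiable Function.
    @staticmethod
    def forward(ctx, x, grad_out):
        ctx.save_for_backward(x, grad_out)
        return 2 * x * grad_out              # d(x**2)/dx, scaled by the incoming gradient

    @staticmethod
    def backward(ctx, gg):
        x, grad_out = ctx.saved_tensors
        return 2 * grad_out * gg, 2 * x * gg

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Delegate to another Function's apply(), analogous to `return f.apply(...)`
        # in the PR's backward(); this keeps the backward itself differentiable.
        return SquareBackward.apply(x, grad_out)

x = torch.tensor(3.0, requires_grad=True)
y = Square.apply(x)
(g,) = torch.autograd.grad(y, x, create_graph=True)  # 2 * x = 6
(gg,) = torch.autograd.grad(g, x)                     # d(2x)/dx = 2
print(g.item(), gg.item())
```

In the PR, the role of SquareBackward is played by a CompiledFunction built on the fly from bw_module, which is why backward re-runs make_fx and partition_fn before constructing it.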
41 changes: 35 additions & 6 deletions test/test_pythonkey.py
@@ -246,25 +246,54 @@ def f(args, kwargs):

def _outs_and_grads(fn, inps):
    outs = fn(*inps)
+    diff_outs = []
    for out in pytree.tree_flatten(outs)[0]:
        if isinstance(out, torch.Tensor) and out.requires_grad:
-            out.sum().backward(retain_graph=True)
-    grads = [inp.grad for inp in pytree.tree_flatten(inps)[0]]
-    for inp in pytree.tree_flatten(inps)[0]:
-        inp.grad = None
+            diff_outs.append(out)
+    def full_reduce(outs):
+        res = 0
+        for out in outs:
+            res = res + out.sum()
+        return res
+    print(inps)
+    grads = torch.autograd.grad(full_reduce(diff_outs), pytree.tree_flatten(inps)[0], create_graph=True)
    return outs, grads

+def _outs_and_grads_and_grad_grads(fn, inps):
+    outs = fn(*inps)
+    diff_outs = []
+    diff_inps = []
+    for out in pytree.tree_flatten(outs)[0]:
+        if isinstance(out, torch.Tensor) and out.requires_grad:
+            diff_outs.append(out)
+    for inp in pytree.tree_flatten(inps)[0]:
+        if isinstance(inp, torch.Tensor) and inp.requires_grad:
+            diff_inps.append(inp)
+    def full_reduce(outs):
+        res = 0
+        for out in outs:
+            res = res + out.sum()
+        return res
+    grads = torch.autograd.grad(full_reduce(diff_outs), diff_inps, create_graph=True)
+    print("grads: ", grads)
+    diff_grads = []
+    for grad_ in grads:
+        if isinstance(grad_, torch.Tensor) and grad_.requires_grad:
+            diff_grads.append(grad_)
+    grad_grads = torch.autograd.grad(full_reduce(diff_grads), diff_inps)
+    return outs, grads, grad_grads

class TestAOTAutograd(TestCase):
    def verify_aot_autograd(self, f, inp):
        if isinstance(f, nn.Module):
            compiled_f = aot_module(f, nop)
        else:
            compiled_f = aot_function(f, nop)
-        ref_out, ref_grad = _outs_and_grads(f, inp)
-        test_out, test_grad = _outs_and_grads(compiled_f, inp)
+        ref_out, ref_grad, ref_grad_grad = _outs_and_grads_and_grad_grads(f, inp)
+        test_out, test_grad, test_grad_grad = _outs_and_grads_and_grad_grads(compiled_f, inp)
        self.assertEqual(ref_out, test_out)
        self.assertEqual(ref_grad, test_grad)
+        # self.assertEqual(ref_grad_grad, test_grad_grad)

    def test_single_output(self):
        def f(a, b):