     base_idx: Optional[int]
     # If it is a Tensor, what the dynamic dims are (otherwise is None)
     dynamic_dims: Optional[Set[int]]
+    # requires_grad
+    requires_grad: bool

 # This class tells us info about user inputs.
@@ -781,6 +783,7 @@ Source code for torch._functorch.aot_autograd
     is_leaf: bool
     mutates_data: bool
     mutates_metadata: bool
+    requires_grad: bool

 # This class encapsulates all aliasing + mutation info we need about the forward graph
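As an aside, a rough sketch of the shape these two dataclasses take after the hunks above. Only the fields visible in this diff are shown; the real InputAliasInfo/OutputAliasInfo also carry fields such as output_type and raw_type that appear later in the patch.

# Illustrative sketch only -- field subset taken from the hunks above, not the full upstream definitions.
from dataclasses import dataclass
from typing import Optional, Set

@dataclass
class InputAliasInfo:
    is_leaf: bool
    mutates_data: bool
    mutates_metadata: bool
    requires_grad: bool  # newly recorded per input

@dataclass
class OutputAliasInfo:
    base_idx: Optional[int]
    dynamic_dims: Optional[Set[int]]
    requires_grad: bool  # newly recorded per output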
@@ -1152,7 +1155,8 @@ Source code for torch._functorch.aot_autograd
         input_info.append(InputAliasInfo(
             is_leaf=isinstance(arg, torch.Tensor) and safe_is_leaf(arg),
             mutates_data=mutates_data,
-            mutates_metadata=mutates_metadata
+            mutates_metadata=mutates_metadata,
+            requires_grad=isinstance(f_arg, torch.Tensor) and f_arg.requires_grad
         ))

     # If a function involves creating a tensor, and returning a view of it, such that its _base is the intermediate,
@@ -1294,11 +1298,10 @@ Source code for torch._functorch.aot_autograd
             raw_type=type(o),
             base_idx=base_idx,
             dynamic_dims=dynamic_dims,
+            requires_grad=isinstance(o, torch.Tensor) and o.requires_grad
         )
         output_info.append(out_info)
-        output_requires_grad_info.append(
-            isinstance(o, torch.Tensor) and o.requires_grad
-        )
+        output_requires_grad_info.append(out_info.requires_grad)

     # Our autograd.Function.forward returns both mutated inputs and outputs,
     # so we need grad info on all of them.
@@ -1313,13 +1316,14 @@ Source code for torch._functorch.aot_autograd
     f_input_tangents = [
         inp
         for inp, info in zip(flat_f_args, input_info)
-        if info.mutates_data
+        if info.mutates_data and info.requires_grad
     ]
     f_output_tangents = [
         o
         for o, info in zip(flat_f_outs, output_info)
         if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
         and issubclass(info.raw_type, torch.Tensor)
+        and info.requires_grad
     ]
     # intermediate bases are also included in the backward graph
     f_tangents = f_input_tangents + f_output_tangents + intermediate_bases
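A loose, self-contained illustration of the selection rule this hunk tightens: mutated inputs only get tangents if they also require grad. The FakeInputInfo class and the sample data are made up; only the filtering expression mirrors the diff.

from dataclasses import dataclass

@dataclass
class FakeInputInfo:  # stand-in for InputAliasInfo, illustration only
    mutates_data: bool
    requires_grad: bool

flat_f_args = ["a", "b", "c"]
input_info = [
    FakeInputInfo(mutates_data=True,  requires_grad=True),   # mutated and differentiable -> gets a tangent
    FakeInputInfo(mutates_data=True,  requires_grad=False),  # mutated but no grad -> now skipped
    FakeInputInfo(mutates_data=False, requires_grad=True),   # not mutated -> never had a tangent
]
f_input_tangents = [
    inp for inp, info in zip(flat_f_args, input_info)
    if info.mutates_data and info.requires_grad
]
assert f_input_tangents == ["a"]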
@@ -1370,7 +1374,7 @@ Source code for torch._functorch.aot_autograd
         to its parameter/buffer FQN in the original nn.Module.
     (3) If there are input mutations, these are represented as extra outputs
         in the fx GraphModule. We provide a mapping from these
-        extra output names to the names of the the actual inputs.
+        extra output names to the names of the actual inputs.
     (4) The pytree metadata on how to flatten/unflatten inputs and outputs.
         The corresponding FX GraphModule only accepts and returns
         pytree-flattened inputs/outputs.
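Point (4) refers to torch's pytree utilities; a minimal illustration of the flatten/unflatten round trip it describes, independent of AOTAutograd (the sample inputs are invented):

import torch
import torch.utils._pytree as pytree

# Nested inputs are flattened into a list of leaves plus a spec describing the structure.
inputs = {"x": torch.randn(2), "nested": (torch.randn(3), 1.5)}
leaves, spec = pytree.tree_flatten(inputs)
assert len(leaves) == 3  # two tensors and a float

# The spec is what lets flat graph outputs be reassembled into the original structure.
rebuilt = pytree.tree_unflatten(leaves, spec)
assert set(rebuilt.keys()) == {"x", "nested"}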
@@ -1510,11 +1514,12 @@ Source code for torch._functorch.aot_autograd
         if idx in meta.mutated_inp_indices:
             # We only need to bother cloning mutated inputs that participate in autograd.
             mutated_inp_idx = meta.mutated_inp_indices.index(idx)
-            if meta.requires_grad_info[mutated_inp_idx] and meta.input_info[idx].mutates_data:
+            assert meta.input_info[idx].requires_grad == meta.requires_grad_info[mutated_inp_idx]
+            if meta.input_info[idx].requires_grad and meta.input_info[idx].mutates_data:
                 # Make sure the primal we pass to autograd.grad()
                 # sees the tensor before the mutation
                 return t.clone()
-            if meta.requires_grad_info[mutated_inp_idx] and meta.input_info[idx].mutates_metadata:
+            if meta.input_info[idx] and meta.input_info[idx].mutates_metadata:
                 # Make sure the primal we pass to autograd.grad()
                 # sees the tensor before the metadata mutation
                 return t.view(t.shape)
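The "sees the tensor before the mutation" comments are about plain autograd semantics; a toy reproduction of the failure mode the clone avoids, using no AOTAutograd APIs:

import torch

x = torch.ones(3, requires_grad=True).clone()  # non-leaf copy that we are allowed to mutate
y = torch.sin(x)   # sin() saves x for its backward formula
x.mul_(2)          # in-place mutation of a tensor autograd still needs
try:
    y.sum().backward()
except RuntimeError as err:
    # autograd detects that a tensor saved for backward was modified in place
    print(err)
# Cloning the primal before the (traced) mutation keeps a pre-mutation copy alive,
# which is what the `return t.clone()` branch above is doing.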
@@ -1588,7 +1593,7 @@ Source code for torch._functorch.aot_autograd
     # Also return a boolean mask specifying which outputs to this function will be used as tangents
     mutated_inputs_grad_mask = [
-        meta.input_info[meta.mutated_inp_indices[i]].mutates_data
+        meta.input_info[meta.mutated_inp_indices[i]].mutates_data and meta.input_info[meta.mutated_inp_indices[i]].requires_grad
         for (i, x) in enumerate(mutated_inputs_to_return)
     ]
@@ -1600,6 +1605,7 @@ Source code for torch._functorch.aot_autograd
         # Also, only tensor outputs should participate in the backward
         # (in particular, Symint outputs in the forward graph shouldn't get tangents)
         and issubclass(meta.output_info[i].raw_type, torch.Tensor)
+        and meta.output_info[i].requires_grad
         for (i, x) in enumerate(outs)
     ]
@@ -2350,7 +2356,8 @@
             mutates_data=True if len(outer_indices) > 1 else m.input_info[outer_indices[0]].mutates_data,
             mutates_metadata=False if len(outer_indices) > 1 else m.input_info[outer_indices[0]].mutates_metadata,
             is_leaf=any_leaf,
+            requires_grad=any(m.input_info[x].requires_grad for x in outer_indices)
         )
         input_infos.append(inpt_info)
         # requires_grad_info consists of (mutated_inputs, forward_outputs).
         # For any mutated inputs that correspond to aliased inputs,
         # Need to replace them with their mutated synthetic base
         if inpt_info.mutates_data or inpt_info.mutates_metadata:
-            mutated_inp_require_grad_info.append(any(m.requires_grad_info[x] for x in outer_indices))
+            for x in outer_indices:
+                assert m.requires_grad_info[x] == m.input_info[x].requires_grad
+            mutated_inp_require_grad_info.append(any(m.input_info[x].requires_grad for x in outer_indices))

     # Find any inputs that fulfill the following criteria:
     # (1) They are part of a synthetic base (because they alias another input,
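The any(...) above merges per-view flags onto a synthetic base; a tiny sketch of the intended behavior (the two tensors below merely stand in for inputs that alias the same storage):

import torch

# Pretend these two user inputs alias the same underlying storage; AOTAutograd
# would replace them with a single synthetic base input.
inp_a = torch.randn(2, requires_grad=False)
inp_b = torch.randn(2, requires_grad=True)

# The synthetic base must require grad if *any* of the aliased views does.
synthetic_base_requires_grad = any(t.requires_grad for t in (inp_a, inp_b))
assert synthetic_base_requires_grad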
@@ -2426,7 +2436,8 @@ Source code for torch._functorch.aot_autograd
     # grab the original requires grad info on the outputs, except the ones from the mutated inputs
     num_original_input_data_mutations = len([x for x in m.input_info if x.mutates_data or x.mutates_metadata])
-    output_grad_info = m.requires_grad_info[num_original_input_data_mutations:]
+    output_grad_info = [x.requires_grad for x in m.output_info]
+    assert output_grad_info == m.requires_grad_info[num_original_input_data_mutations:]
     input_metadata_mutation_grad_info = [
         outer_args[outer_idx].requires_grad for outer_idx in outer_aliased_arg_idx_with_metadata_mutations
     ]
     input_metadata_output_info = [
@@ -2435,6 +2446,7 @@
         seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
         adjusted_flat_args.extend([seed, offset])
         # We are not clearing flat_args here because
-        # 1) There is a check in the the debug compiler at the end
+        # 1) There is a check in the debug compiler at the end
         # 2) It does not matter as these are fake tensors
     if torch._guards.TracingContext.get():
@@ -3450,13 +3474,22 @@ Source code for torch._functorch.aot_autograd
             # invariant: intermediate bases always require gradients, so we don't have to
             # consider marking them as non-differentiable.
             raw_returns_not_including_intermediate_bases = raw_returns[:num_mutated_inputs + num_outputs]
+
+            raw_returns_meta = (
+                [x for x in CompiledFunction.metadata.input_info if x.mutates_data or x.mutates_metadata]
+                + CompiledFunction.metadata.output_info
+            )
+
+            for (i, x) in enumerate(raw_returns_not_including_intermediate_bases):
+                assert CompiledFunction.metadata.requires_grad_info[i] == raw_returns_meta[i].requires_grad
             fw_outs_not_requiring_grad = [
                 x
                 for (i, x) in enumerate(raw_returns_not_including_intermediate_bases)
                 if isinstance(x, torch.Tensor)
-                and not CompiledFunction.metadata.requires_grad_info[i]
+                and not raw_returns_meta[i].requires_grad
             ]
             ctx.mark_non_differentiable(*fw_outs_not_requiring_grad)
+            ctx._materialize_non_diff_grads = False

             functionalized_rng_runtime_epilogue(CompiledFunction.metadata,
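ctx.mark_non_differentiable() is the standard torch.autograd.Function hook this code path relies on; a minimal, self-contained example of what it does (unrelated to the compiled function above, all names invented):

import torch

class SplitOut(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        diff = x * 2
        aux = torch.argsort(x)            # integer output, never differentiable
        ctx.mark_non_differentiable(aux)  # autograd will not propagate a real grad for it
        return diff, aux

    @staticmethod
    def backward(ctx, grad_diff, grad_aux):
        # grad_aux is ignored; only the differentiable output contributes
        return grad_diff * 2

x = torch.randn(4, requires_grad=True)
diff, aux = SplitOut.apply(x)
diff.sum().backward()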
@@ -3484,56 +3517,46 @@ Source code for torch._functorch.aot_autograd
             assert len(flat_args) == expected_grad_outs
             out_info = CompiledFunction.metadata.output_info
-            if (
-                CompiledFunction.metadata.num_mutated_metadata_only_inputs > 0
-                or CompiledFunction.metadata.num_outputs_aliased > 0
-            ):
-                inp_tangents, out_tangents, intermediate_base_tangents = (
-                    flat_args[0:num_mutated_inps],
-                    flat_args[num_mutated_inps:num_mutated_inps + CompiledFunction.metadata.num_outputs],
-                    flat_args[num_mutated_inps + CompiledFunction.metadata.num_outputs:],
-                )
-                # input_info contains info on *every* input,
-                # But in the backward(), we are only given grad outputs for every mutated input.
-                # We then need to filter out the grad outputs that correspond to metadata-only mutations.
-                mutated_inp_indices = CompiledFunction.metadata.mutated_inp_indices
-                input_info = CompiledFunction.metadata.input_info
-                assert len(inp_tangents) == len(mutated_inp_indices)
-                inp_tangents_filtered = [
-                    x
-                    for x, info_idx in zip(inp_tangents, mutated_inp_indices)
-                    if input_info[info_idx].mutates_data
-                ]
-                # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates
-                out_tangents_filtered = [
-                    x
-                    for x, info in zip(out_tangents, out_info)
-                    if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
-                    and issubclass(info.raw_type, torch.Tensor)
-                ]
-                # intermediate bases always require gradients, and always participate in the backward graph.
-                flat_bw_args = itertools.chain(inp_tangents_filtered, out_tangents_filtered, intermediate_base_tangents)
-
-                # sanity asserts
-                # metadata_only_inps = [
-                #     x for x, info_idx in zip(inp_tangents, mutated_inp_indices)
-                #     if not input_info[info_idx].mutates_data
-                # ]
-                # aliased_outputs = [
-                #     x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias]
-                # assert all(x is None for x in metadata_only_inps)
-                # assert all(x is None for x in aliased_outputs)
-            else:
-                # filter out non-tensor grad_outputs (aka due to ints being returned as outputs in the forward)
-                num_mutated_inps = CompiledFunction.metadata.num_mutated_inputs
-                mutated_inp_args = flat_args[:num_mutated_inps] if num_mutated_inps > 0 else []
-                user_tangents = flat_args[num_mutated_inps:]
-                assert len(user_tangents) == len(out_info)
-                filtered_user_tangents = [x for x, info in zip(user_tangents, out_info) if issubclass(info.raw_type, torch.Tensor)]
-                flat_bw_args = tuple(mutated_inp_args) + tuple(filtered_user_tangents)
+
+            inp_tangents, out_tangents, intermediate_base_tangents = (
+                flat_args[0:num_mutated_inps],
+                flat_args[num_mutated_inps:num_mutated_inps + CompiledFunction.metadata.num_outputs],
+                flat_args[num_mutated_inps + CompiledFunction.metadata.num_outputs:],
+            )
+            # input_info contains info on *every* input,
+            # But in the backward(), we are only given grad outputs for every mutated input
+            # We then need to filter out the grad outputs that correspond to metadata-only mutations or don't require grad
+            mutated_inp_indices = CompiledFunction.metadata.mutated_inp_indices
+            input_info = CompiledFunction.metadata.input_info
+            assert len(inp_tangents) == len(mutated_inp_indices)
+            inp_tangents_filtered = [
+                x
+                for x, info_idx in zip(inp_tangents, mutated_inp_indices)
+                if input_info[info_idx].mutates_data and input_info[info_idx].requires_grad
+            ]
+            # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates
+            out_tangents_filtered = [
+                x
+                for x, info in zip(out_tangents, out_info)
+                if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
+                and issubclass(info.raw_type, torch.Tensor)
+                and info.requires_grad
+            ]
+            # intermediate bases always require gradients, and always participate in the backward graph.
+            flat_bw_args_with_grads = itertools.chain(inp_tangents_filtered, out_tangents_filtered, intermediate_base_tangents)
+
+            # sanity asserts
+            # metadata_only_inps = [
+            #     x for x, info_idx in zip(inp_tangents, mutated_inp_indices)
+            #     if not input_info[info_idx].mutates_data
+            # ]
+            # aliased_outputs = [
+            #     x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias]
+            # assert all(x is None for x in metadata_only_inps)
+            # assert all(x is None for x in aliased_outputs)

             contiguous_args = [
-                t.contiguous() if torch.is_tensor(t) else t for t in flat_bw_args
+                t.contiguous() if torch.is_tensor(t) else t for t in flat_bw_args_with_grads
             ]

             rng_args = []
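For orientation, the slicing at the top of this block partitions the incoming grad outputs purely by position; a toy sketch with made-up sizes (none of these numbers come from the patch):

# Illustration only: positional layout of the grad outputs entering backward().
num_mutated_inps, num_outputs = 2, 3
flat_args = list(range(7))  # 2 mutated-input grads, 3 output grads, 2 intermediate-base grads

inp_tangents = flat_args[0:num_mutated_inps]
out_tangents = flat_args[num_mutated_inps:num_mutated_inps + num_outputs]
intermediate_base_tangents = flat_args[num_mutated_inps + num_outputs:]
assert (inp_tangents, out_tangents, intermediate_base_tangents) == ([0, 1], [2, 3, 4], [5, 6])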
@@ -3762,7 +3785,12 @@ Source code for torch._functorch.aot_autograd
         # This should be rare, and is tricky to get right. When we trace the backward,
         # we currently trace with autograd.grad instead of .backward(), which makes it difficult
         # to ensure that we run autograd all the way through the input **before** it saw the mutation.
-        if len([x for x in fw_metadata.requires_grad_info[:fw_metadata.num_mutated_inputs] if x]) != 0:
+        assert (
+            len([x for x in fw_metadata.input_info if x.requires_grad and x.mutates_data])
+            ==
+            len([x for x in fw_metadata.requires_grad_info[:fw_metadata.num_mutated_inputs] if x])
+        )
+        if len([x for x in fw_metadata.input_info if x.requires_grad and x.mutates_data]) != 0:
             raise RuntimeError(f"""\
Found a graph input that requires gradients, and received a mutation.
This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue.
diff --git a/nightly/_modules/torch/_functorch/compilers.html b/nightly/_modules/torch/_functorch/compilers.html
index 0c030fdfb..f25325738 100644
--- a/nightly/_modules/torch/_functorch/compilers.html
+++ b/nightly/_modules/torch/_functorch/compilers.html
@@ -215,7 +215,7 @@