Skip to content

Commit

Permalink
Added batched raycast tests and docs.
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelalonsojr committed Jul 20, 2023
1 parent e241994 commit 039ece6
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 10 deletions.
3 changes: 2 additions & 1 deletion com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to
### Major Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
- The minimum supported Unity version was updated to 2022.3. (#)
- Added batched raycast sensor option. (#)

#### ml-agents / ml-agents-envs

Expand Down Expand Up @@ -47,7 +48,7 @@ versioned under `ml-agents-envs` package in the future (#)
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
- Added switch to RayPerceptionSensor to allow rays to be ordered left to right. (#26)
- Current alternating order is still the default but will be deprecated.
- Added suppport for enabling/disabling camera object attached to camera sensor in order to improve performance. (#31)
- Added support for enabling/disabling camera object attached to camera sensor in order to improve performance. (#31)

#### ml-agents / ml-agents-envs
- Renaming the path that shadows torch with "mlagents/trainers/torch_entities" and update respective imports (#)
Expand Down
28 changes: 22 additions & 6 deletions com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -400,20 +400,34 @@ public BuiltInSensorType GetBuiltInSensorType()
/// Evaluates the raycasts to be used as part of an observation of an agent.
/// </summary>
/// <param name="input">Input defining the rays that will be cast.</param>
/// <param name="batched">Use batched raycasts.</param>
/// <returns>Output struct containing the raycast results.</returns>
public static RayPerceptionOutput Perceive(RayPerceptionInput input)
public static RayPerceptionOutput Perceive(RayPerceptionInput input, bool batched)
{
RayPerceptionOutput output = new RayPerceptionOutput();
output.RayOutputs = new RayPerceptionOutput.RayOutput[input.Angles.Count];

for (var rayIndex = 0; rayIndex < input.Angles.Count; rayIndex++)
if (batched)
{
output.RayOutputs[rayIndex] = PerceiveSingleRay(input, rayIndex);
PerceiveBatchedRays(ref output.RayOutputs, input);
}
else
{
for (var rayIndex = 0; rayIndex < input.Angles.Count; rayIndex++)
{
output.RayOutputs[rayIndex] = PerceiveSingleRay(input, rayIndex);
}
}

return output;
}

/// <summary>
/// Evaluate the raycast results of all the rays from the RayPerceptionInput as a batch.
/// </summary>
/// <param name="input"></param>
/// <param name="rayIndex"></param>
/// <returns></returns>
internal static void PerceiveBatchedRays(ref RayPerceptionOutput.RayOutput[] batchedRaycastOutputs, RayPerceptionInput input)
{
var numRays = input.Angles.Count;
Expand Down Expand Up @@ -445,13 +459,15 @@ internal static void PerceiveBatchedRays(ref RayPerceptionOutput.RayOutput[] bat
var queryParameters = QueryParameters.Default;
queryParameters.layerMask = input.LayerMask;

var rayDirectionNormalized = rayDirection.normalized;

if (scaledCastRadius > 0f)
{
spherecastCommands[i] = new SpherecastCommand(startPositionWorld, scaledCastRadius, rayDirection, queryParameters, scaledRayLength);
spherecastCommands[i] = new SpherecastCommand(startPositionWorld, scaledCastRadius, rayDirectionNormalized, queryParameters, scaledRayLength);
}
else
{
raycastCommands[i] = new RaycastCommand(startPositionWorld, rayDirection, queryParameters, scaledRayLength);
raycastCommands[i] = new RaycastCommand(startPositionWorld, rayDirectionNormalized, queryParameters, scaledRayLength);
}

batchedRaycastOutputs[i] = new RayPerceptionOutput.RayOutput
Expand Down Expand Up @@ -494,7 +510,7 @@ internal static void PerceiveBatchedRays(ref RayPerceptionOutput.RayOutput[] bat

// hitFraction = castHit ? (scaledRayLength > 0 ? results[i].distance / scaledRayLength : 0.0f) : 1.0f;
// Debug.Log(results[i].distance);
hitFraction = castHit ? (scaledRayLength > 0 ? results[i].distance : 0.0f) : 1.0f;
hitFraction = castHit ? (scaledRayLength > 0 ? results[i].distance / scaledRayLength : 0.0f) : 1.0f;
hitObject = castHit ? results[i].collider.gameObject : null;

if (castHit)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ public void TestStaticPerceive()
{
perception.SphereCastRadius = castRadius;
var castInput = perception.GetRayPerceptionInput();
var castOutput = RayPerceptionSensor.Perceive(castInput);
var castOutput = RayPerceptionSensor.Perceive(castInput, false);

Assert.AreEqual(1, castOutput.RayOutputs.Length);

Expand Down Expand Up @@ -391,7 +391,7 @@ public void TestStaticPerceiveInvalidTags()
// There's no clean way that I can find to check for a defined tag without
// logging an error.
LogAssert.Expect(LogType.Error, "Tag: Bad tag is not defined.");
var castOutput = RayPerceptionSensor.Perceive(castInput);
var castOutput = RayPerceptionSensor.Perceive(castInput, false);

Assert.AreEqual(1, castOutput.RayOutputs.Length);

Expand All @@ -418,7 +418,7 @@ public void TestStaticPerceiveNoTags()
{
perception.SphereCastRadius = castRadius;
var castInput = perception.GetRayPerceptionInput();
var castOutput = RayPerceptionSensor.Perceive(castInput);
var castOutput = RayPerceptionSensor.Perceive(castInput, false);

Assert.AreEqual(1, castOutput.RayOutputs.Length);

Expand All @@ -428,6 +428,97 @@ public void TestStaticPerceiveNoTags()
}
}

[Test]
public void TestStaticBatchedPerceive()
{
SetupScene();
var obj = new GameObject("agent");
var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();

perception.RaysPerDirection = 1; // three rays
perception.MaxRayDegrees = 45;
perception.RayLength = 20;
perception.DetectableTags = new List<string>();
perception.DetectableTags.Add(k_CubeTag);
perception.DetectableTags.Add(k_SphereTag);

var radii = new[] { 0f, .5f };
foreach (var castRadius in radii)
{
perception.SphereCastRadius = castRadius;
var castInput = perception.GetRayPerceptionInput();
var castOutput = RayPerceptionSensor.Perceive(castInput, true);

Assert.AreEqual(3, castOutput.RayOutputs.Length);

// Expected to hit the cube
Assert.AreEqual("cube", castOutput.RayOutputs[0].HitGameObject.name);
Assert.AreEqual(0, castOutput.RayOutputs[0].HitTagIndex);
}
}

[Test]
public void TestStaticPerceiveBatchedInvalidTags()
{
SetupScene();
var obj = new GameObject("agent");
var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();

perception.RaysPerDirection = 0; // three rays
perception.MaxRayDegrees = 45;
perception.RayLength = 20;
perception.DetectableTags = new List<string>();
perception.DetectableTags.Add("Bad tag");
perception.DetectableTags.Add(null);
perception.DetectableTags.Add("");
perception.DetectableTags.Add(k_CubeTag);

var radii = new[] { 0f, .5f };
foreach (var castRadius in radii)
{
perception.SphereCastRadius = castRadius;
var castInput = perception.GetRayPerceptionInput();

// There's no clean way that I can find to check for a defined tag without
// logging an error.
LogAssert.Expect(LogType.Error, "Tag: Bad tag is not defined.");
var castOutput = RayPerceptionSensor.Perceive(castInput, true);

Assert.AreEqual(1, castOutput.RayOutputs.Length);

// Expected to hit the cube
Assert.AreEqual("cube", castOutput.RayOutputs[0].HitGameObject.name);
Assert.AreEqual(3, castOutput.RayOutputs[0].HitTagIndex);
}
}

[Test]
public void TestStaticPerceiveBatchedNoTags()
{
SetupScene();
var obj = new GameObject("agent");
var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();

perception.RaysPerDirection = 1; // single ray
perception.MaxRayDegrees = 45;
perception.RayLength = 20;
perception.DetectableTags = null;

var radii = new[] { 0f, .5f };
foreach (var castRadius in radii)
{
perception.SphereCastRadius = castRadius;
var castInput = perception.GetRayPerceptionInput();
var castOutput = RayPerceptionSensor.Perceive(castInput, true);

Assert.AreEqual(3, castOutput.RayOutputs.Length);

// Expected to hit the cube
Assert.AreEqual("cube", castOutput.RayOutputs[0].HitGameObject.name);
Assert.AreEqual(-1, castOutput.RayOutputs[0].HitTagIndex);
}
}

[Test]
public void TestCreateDefault()
{
Expand Down
3 changes: 3 additions & 0 deletions docs/Learning-Environment-Design-Agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ Both sensor components have several settings:
delta, ..., (n-1)*delta, n*delta). For general usage there is no difference
but if using custom models the left-to-right layout that matches the spatial
structuring can be preferred (e.g. for processing with conv nets).
- _Use Batched Raycasts_ (3D only) Whether to use batched raycasts. Enable to use batched raycasts and the jobs system.

In the example image above, the Agent has two `RayPerceptionSensorComponent3D`s.
Both use 3 Rays Per Direction and 90 Max Ray Degrees. One of the components had
Expand All @@ -528,6 +529,8 @@ setting the State Size.
for the agent that doesn't require a fully rendered image to convey.
- Use as few rays and tags as necessary to solve the problem in order to improve
learning stability and agent performance.
- If you run into performance issues, try using batched raycasts by enabling the _Use Batched Raycast_ setting.
(Only available for 3D ray perception sensors.)

### Grid Observations
Grid-base observations combine the advantages of 2D spatial representation in
Expand Down

0 comments on commit 039ece6

Please sign in to comment.