dirichlet_custom.py
import numpy as np

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()


# This class is only needed because TorchDirichlet's KL-divergence
# implementation is incorrect; see the note on kl() below.
@DeveloperAPI
class TorchDirichlet_Custom(TorchDistributionWrapper):
    """Dirichlet distribution for continuous actions that lie in [0, 1]
    and sum to 1, e.g. actions that represent a resource allocation."""

    def __init__(self, inputs, model):
        """Input is a tensor of logits. The exponential of the logits is
        used to parametrize the Dirichlet distribution, since all
        concentration parameters must be positive. A small epsilon is
        added to keep the concentration parameters from collapsing to
        zero due to numerical error. See issue #4440 for more details.
        """
        self.epsilon = torch.tensor(1e-7).to(inputs.device)
        concentration = torch.exp(inputs) + self.epsilon
        self.dist = torch.distributions.dirichlet.Dirichlet(
            concentration=concentration,
            validate_args=True,
        )
        super().__init__(concentration, model)
    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        # Pass dim=1 explicitly to avoid the deprecation warning about
        # implicit softmax dimensions.
        self.last_sample = nn.functional.softmax(self.dist.concentration, dim=1)
        return self.last_sample
    @override(ActionDistribution)
    def logp(self, x):
        # The support of the Dirichlet is the open simplex of positive
        # reals. x should already be positive, but we clip away zeros
        # caused by numerical error and renormalize so the entries sum
        # to 1.
        x = torch.max(x, self.epsilon)
        x = x / torch.sum(x, dim=-1, keepdim=True)
        return self.dist.log_prob(x)
    @override(ActionDistribution)
    def entropy(self):
        return self.dist.entropy()
    # NOTE: kl() is deliberately left out. The override below (as found
    # in TorchDirichlet) is incorrect, so we fall back to the default
    # implementation of the superclass instead.
    # @override(ActionDistribution)
    # def kl(self, other):
    #     return self.dist.kl_divergence(other.dist)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        return np.prod(action_space.shape)
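
For context, here is a minimal sketch of how a distribution class like this is typically wired into RLlib (not part of the original file). It assumes RLlib's legacy ModelCatalog API and the old-style config dict; the registered name "torch_dirichlet_custom" is an arbitrary choice for illustration.

from ray.rllib.models import ModelCatalog

# Register the distribution under a name the model config can refer to.
# The name "torch_dirichlet_custom" is arbitrary (an assumption, not
# something the original file prescribes).
ModelCatalog.register_custom_action_dist(
    "torch_dirichlet_custom", TorchDirichlet_Custom
)

config = {
    "framework": "torch",
    "model": {
        # Route the policy's action sampling through the class above.
        "custom_action_dist": "torch_dirichlet_custom",
    },
}

With this in place, an algorithm built from the config samples actions from TorchDirichlet_Custom instead of the built-in TorchDirichlet, sidestepping the broken KL computation.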