envs package¶
- envs.register_env(environment_name: str) Callable[[_T], _T] ¶
the decorator to register an environment
to register a new class as an environment, ensure the class derives from envs.BaseEnv, then add the following code to the top of the file:
@register_env("env_name")
class YourEnvClass(BaseEnv): ...
- Parameters:
environment_name (str) – the name of the registered environment
- Returns:
the decorator to register the environment
- Return type:
Callable[[_T], _T]
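The registry pattern behind register_env can be sketched as follows. This is a minimal stand-alone sketch, not the package's actual implementation; the registry dict name and the BaseEnv stand-in are assumptions for illustration:

```python
from typing import Callable, Dict, TypeVar

_T = TypeVar("_T")
ENV_REGISTRY: Dict[str, type] = {}  # hypothetical module-level registry


def register_env(environment_name: str) -> Callable[[_T], _T]:
    """Return a decorator that records the decorated class under the given name."""
    def decorator(cls: _T) -> _T:
        ENV_REGISTRY[environment_name] = cls  # type: ignore[index]
        return cls
    return decorator


class BaseEnv:  # stand-in for envs.BaseEnv
    pass


@register_env("my_env")
class MyEnv(BaseEnv):
    pass
```

After the decorator runs, the class can be looked up by name, e.g. `ENV_REGISTRY["my_env"]`.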
Submodules¶
envs.BaseEnv module¶
- class envs.BaseEnv.BaseEnv(args: Namespace, device: Optional[str] = None)¶
Bases:
object
the base environment class; should be subclassed by specific environments
- static add_args(parser: ArgumentParser) None ¶
add arguments to the parser
to add arguments to the parser, modify the method as follows:
@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument( ... )
- Parameters:
parser (argparse.ArgumentParser) – the parser to add arguments to
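A concrete add_args override might look like the sketch below. The `--window-size` argument is purely illustrative, not one the package actually defines:

```python
import argparse


class MyEnv:  # stand-in for a BaseEnv subclass
    @staticmethod
    def add_args(parser: argparse.ArgumentParser) -> None:
        # hypothetical argument; real environments register their own
        parser.add_argument(
            "--window-size",
            type=int,
            default=10,
            help="number of past days fed to the agent",
        )


parser = argparse.ArgumentParser()
MyEnv.add_args(parser)
args = parser.parse_args(["--window-size", "20"])
```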
- __init__(args: Namespace, device: Optional[str] = None) None ¶
initialize the environment
- Parameters:
args (argparse.Namespace) – arguments
device (Optional[str], optional) – device to run the environment. Defaults to None, which means to use the GPU if available.
- initialize_weight() None ¶
initialize the portfolio weight, risk free asset weight, and value
- get_asset_num() int ¶
get the number of assets, should be overridden by specific environments
- Raises:
NotImplementedError – asset num not implemented
- Returns:
the number of assets
- Return type:
int
- to(device: str) None ¶
move the environment to the given device
- Parameters:
device (str) – the device to move to
- train_time_range() range ¶
the range of time indices, should be overridden by specific environments
- Raises:
NotImplementedError – time_range not implemented
- Returns:
the range of time indices
- Return type:
range
- test_time_range() range ¶
the range of time indices, should be overridden by specific environments
- Raises:
NotImplementedError – time_range not implemented
- Returns:
the range of time indices
- Return type:
range
- state_dimension() Dict[str, Size] ¶
the dimension of the state tensors, should be overridden by specific environments
- Raises:
NotImplementedError – state_dimension not implemented
- Returns:
the dimension of the state tensors
- Return type:
Dict[str, torch.Size]
- state_tensor_names() List[str] ¶
the names of the state tensors, should be overridden by specific environments
- Raises:
NotImplementedError – state_tensor_names not implemented
- Returns:
the names of the state tensors
- Return type:
List[str]
- action_dimension() Size ¶
the dimension of the action the agent can take
- Returns:
the dimension of the action the agent can take
- Return type:
torch.Size
- get_state(state: Optional[Dict[str, Union[Tensor, int]]] = None) Dict[str, Union[Tensor, int]] ¶
get the state tensors at the current time, should be overridden by specific environments
- Parameters:
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
- Raises:
NotImplementedError – get_state not implemented
- Returns:
the state tensors
- Return type:
Dict[str, Union[torch.Tensor, int]]
- act(action: Tensor, state: Optional[Dict[str, Union[Tensor, int]]] = None) Tuple[Dict[str, Optional[Union[Tensor, int]]], Tensor, bool] ¶
update the environment with the given action at the given time, should be overridden by specific environments
- Parameters:
action (torch.Tensor) – the action to perform
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
- Raises:
NotImplementedError – act not implemented
- Returns:
the new state, reward, and whether the episode is done
- Return type:
Tuple[Dict[str, Optional[Union[torch.Tensor, int]]], torch.Tensor, bool]
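The get_state/act pair supports the usual episode loop. The toy environment below is a pure-Python stand-in (no torch) just to show the loop shape; a real BaseEnv subclass returns tensor-valued states and rewards:

```python
from typing import Dict, Tuple, Union


class DummyEnv:
    """Toy stand-in for a BaseEnv subclass: 3 steps, reward equals the action."""

    def __init__(self) -> None:
        self.t = 0

    def reset(self) -> None:
        self.t = 0

    def get_state(self) -> Dict[str, Union[float, int]]:
        return {"time_index": self.t}

    def act(self, action: float) -> Tuple[Dict[str, Union[float, int]], float, bool]:
        self.t += 1
        return {"time_index": self.t}, action, self.t >= 3


env = DummyEnv()
env.reset()
done = False
total_reward = 0.0
while not done:
    state = env.get_state()
    action = 1.0  # a policy would choose this from `state`
    state, reward, done = env.act(action)
    total_reward += reward
```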
- reset() None ¶
reset the environment, should be overridden by specific environments
- Raises:
NotImplementedError – reset not implemented
- update(trading_size: Tensor = None, state: Optional[Dict[str, Union[Tensor, int]]] = None, modify_inner_state: Optional[bool] = None) Dict[str, Union[Tensor, int]] ¶
update the environment with the given trading size of each asset
- Parameters:
trading_size (torch.Tensor) – the trading size of each asset. Defaults to None.
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
modify_inner_state (Optional[bool], optional) – whether to modify the inner state. Defaults to None.
- Returns:
the new state
- Return type:
Dict[str, Union[torch.Tensor, int]]
- _concat_weight(portfolio_weight: Tensor, rf_weight: Tensor) Tensor ¶
concat the portfolio weight and risk free weight, the risk free weight is at the first position
- Parameters:
portfolio_weight (torch.Tensor) – the portfolio weight
rf_weight (torch.Tensor) – the risk free weight
- Returns:
the concatenated weight
- Return type:
torch.Tensor
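The documented ordering (risk-free weight first, then the asset weights) can be shown with a pure-Python sketch; the torch version would be a single `torch.cat`:

```python
from typing import List


def concat_weight(portfolio_weight: List[float], rf_weight: List[float]) -> List[float]:
    """Place the risk-free weight at the first position, then the asset weights.
    (A torch version would be torch.cat((rf_weight, portfolio_weight)).)"""
    return rf_weight + portfolio_weight


full = concat_weight([0.3, 0.5], [0.2])
```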
- _get_price_tensor(time_index: Optional[int] = None) Tensor ¶
get the price tensor at a given time, should be overridden by specific environments
- Parameters:
time_index (Optional[int], optional) – the time index to get the price tensor. Defaults to None, which means to get the price tensor at the current time.
- Raises:
NotImplementedError – _get_price_tensor not implemented
- Returns:
the price tensor
- Return type:
torch.Tensor
- _get_price_change_ratio_tensor(time_index: Optional[int] = None) Tensor ¶
get the price change ratio tensor at a given time, should be overridden by specific environments
- Parameters:
time_index (Optional[int], optional) – the time index to get the price change ratio. Defaults to None, which means to get the price change ratio at the current time.
- Raises:
NotImplementedError – _get_price_change_ratio_tensor not implemented
- Returns:
the price change ratio tensor
- Return type:
torch.Tensor
- _transaction_cost(trading_size: Tensor) Tensor ¶
compute the transaction cost of the trading
transaction_cost =
  sum(abs(trading_size[i]) for i where trading_size[i] > 0) * transaction_cost_rate_buy
+ sum(abs(trading_size[i]) for i where trading_size[i] < 0) * transaction_cost_rate_sell
+ count(trading_size[i] != 0) * transaction_cost_base
- Parameters:
trading_size (torch.Tensor) – the trading size of each asset
- Returns:
the transaction cost
- Return type:
torch.Tensor
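The cost formula above can be checked with a pure-Python sketch (the real method operates on tensors, and the rates come from the environment's args; the values below are illustrative):

```python
from typing import List


def transaction_cost(
    trading_size: List[float],
    rate_buy: float,   # hypothetical rates; the env reads them from args
    rate_sell: float,
    cost_base: float,
) -> float:
    """Buy-side cost + sell-side cost + a fixed fee per nonzero trade."""
    buy_total = sum(s for s in trading_size if s > 0)
    sell_total = sum(-s for s in trading_size if s < 0)
    n_trades = sum(1 for s in trading_size if s != 0)
    return buy_total * rate_buy + sell_total * rate_sell + n_trades * cost_base


# buy 100, sell 50, one fixed fee per trade:
cost = transaction_cost([100.0, -50.0, 0.0], rate_buy=0.01, rate_sell=0.02, cost_base=1.0)
```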
- _get_trading_size_according_to_weight_after_trade(portfolio_weight_before_trade: Tensor, portfolio_weight_after_trade: Tensor, portfolio_value_before_trade: Tensor) Tuple[Tensor, Tensor] ¶
find the trading size according to the weights before and after trading (without advancing the day)
Reference: https://arxiv.org/abs/1706.10059 (Section 2.3)
- Parameters:
portfolio_weight_before_trade (torch.Tensor) – the weight before trading
portfolio_weight_after_trade (torch.Tensor) – the weight after trading
portfolio_value_before_trade (torch.Tensor) – the portfolio value before trading
- Returns:
the trading size and the mu
- Return type:
Tuple[torch.Tensor, torch.Tensor]
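Section 2.3 of the referenced paper derives a transaction-remainder factor mu via a fixed-point iteration. The sketch below is our reading of that derivation, not this package's code: index 0 is assumed to be the cash (risk-free) weight, and the commission rates are illustrative defaults:

```python
from typing import List


def solve_mu(
    w_before: List[float],   # weights before trading (cash at index 0)
    w_after: List[float],    # target weights after trading (cash at index 0)
    c_buy: float = 0.0025,   # purchasing commission rate (illustrative)
    c_sell: float = 0.0025,  # selling commission rate (illustrative)
    n_iter: int = 100,
) -> float:
    """Fixed-point iteration for mu, as we read arXiv:1706.10059 Sec. 2.3."""
    k = c_sell + c_buy - c_sell * c_buy
    mu = 1.0 - c_buy - c_sell  # crude initial guess
    for _ in range(n_iter):
        # total amount that must be sold, given the current mu estimate
        sell = sum(max(wb - mu * wa, 0.0)
                   for wb, wa in zip(w_before[1:], w_after[1:]))
        mu = (1.0 - c_buy * w_before[0] - k * sell) / (1.0 - c_buy * w_after[0])
    return mu


mu_no_trade = solve_mu([0.2, 0.3, 0.5], [0.2, 0.3, 0.5])  # no rebalancing
mu_rebalance = solve_mu([0.0, 0.5, 0.5], [0.0, 0.0, 1.0])  # full swap into asset 2
```

When the weights are unchanged, mu converges to 1 (no value lost to fees); any actual rebalancing yields mu strictly between 0 and 1.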
- _get_new_portfolio_weight_and_value(trading_size: Tensor, state: Optional[Dict[str, Union[Tensor, int]]] = None) Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor] ¶
get the new portfolio weight and value after trading and transitioning to the next day
- Parameters:
trading_size (torch.Tensor) – the trading size of each asset
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
- Returns:
the new portfolio weight, the new portfolio weight at the next day, the new risk free weight, the new risk free weight at the next day, the new portfolio value, the new portfolio value at the next day, and the portfolio value at the next day with static weight
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
- _cash_shortage(trading_size: Tensor, portfolio_value: Optional[Tensor] = None, rf_weight: Optional[Tensor] = None) bool ¶
assert whether there is cash shortage after trading
- Parameters:
trading_size (torch.Tensor) – the trading size of each asset
portfolio_value (Optional[torch.Tensor], optional) – the portfolio value. Defaults to None.
rf_weight (Optional[torch.Tensor], optional) – the risk free weight. Defaults to None.
- Returns:
whether there is cash shortage after trading
- Return type:
bool
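One plausible reading of this check: after paying for net purchases plus transaction costs, the cash position must stay non-negative. The cost model and proportional rate below are assumptions for illustration:

```python
from typing import List


def cash_shortage(
    trading_size: List[float],  # buys positive, sells negative
    portfolio_value: float,
    rf_weight: float,
    cost_rate: float = 0.001,   # hypothetical proportional cost rate
) -> bool:
    """True if trading would drive the cash position below zero."""
    cash = rf_weight * portfolio_value
    net_spend = sum(trading_size)
    cost = cost_rate * sum(abs(s) for s in trading_size)
    return cash - net_spend - cost < 0.0


# spending 90 against 100 * 0.5 = 50 of cash is a shortage
shortage = cash_shortage([90.0, 0.0], portfolio_value=100.0, rf_weight=0.5)
```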
- _asset_shortage(trading_size: Tensor, portfolio_weight: Optional[Tensor] = None, portfolio_value: Optional[Tensor] = None) bool ¶
assert whether there is asset shortage after trading
- Parameters:
trading_size (torch.Tensor) – the trading size of each asset
portfolio_weight (torch.Tensor) – the portfolio weight. default to None (Use the current portfolio weight)
portfolio_value (torch.Tensor) – the portfolio value. default to None (Use the current portfolio value)
- Returns:
whether there is asset shortage after trading
- Return type:
bool
- select_random_action() Tensor ¶
select a random action, should be overridden by specific environments
- Raises:
NotImplementedError – select_random_action not implemented
- Returns:
the random action
- Return type:
torch.Tensor
- get_momentum_action() Tensor ¶
get the momentum action, should be overridden by specific environments
- Raises:
NotImplementedError – get_momentum_action not implemented
- Returns:
the momentum action
- Return type:
torch.Tensor
- get_reverse_momentum_action() Tensor ¶
get the reverse momentum action, should be overridden by specific environments
- Raises:
NotImplementedError – get_reverse_momentum_action not implemented
- Returns:
the reverse momentum action
- Return type:
torch.Tensor
envs.BasicContinuousRealDataEnv module¶
- class envs.BasicContinuousRealDataEnv.BasicContinuousRealDataEnv(args: Namespace, data: Data, device: Optional[device] = None)¶
Bases:
BasicRealDataEnv
- static add_args(parser: ArgumentParser) None ¶
add arguments to the parser
to add arguments to the parser, modify the method as follows:
@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument( ... )
- Parameters:
parser (argparse.ArgumentParser) – the parser to add arguments to
- __init__(args: Namespace, data: Data, device: Optional[device] = None) None ¶
initialize the environment
- Parameters:
args (argparse.Namespace) – arguments
data (Data) – the data
device (Optional[str], optional) – device to run the environment. Defaults to None, which means to use the GPU if available.
- to(device: str) None ¶
move the environment to the given device
- Parameters:
device (str) – the device to move to
- state_dimension() Dict[str, Size] ¶
the dimension of the state tensors.
- Returns:
the dimension of the state tensors
- Return type:
Dict[str, torch.Size]
- state_tensor_names() List[str] ¶
the names of the state tensors
- Returns:
the names of the state tensors
- Return type:
List[str]
- get_state(state: Optional[Dict[str, Union[Tensor, int]]] = None) Optional[Dict[str, Union[Tensor, int]]] ¶
get the state tensors at the current time.
- Parameters:
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
- Returns:
the state tensors
- Return type:
Dict[str, torch.Tensor]
- act(action: Tensor, state: Optional[Dict[str, Union[Tensor, int]]] = None) Tuple[Dict[str, Optional[Union[Tensor, int]]], Tensor, bool] ¶
perform an action (the trading size)
- Parameters:
action (torch.Tensor) – the action to perform
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
- Returns:
the new state, reward, and whether the episode is done
- Return type:
Tuple[Dict[str, Optional[Union[torch.Tensor, int]]], torch.Tensor, bool]
- update(action: Tensor = None, state: Optional[Dict[str, Union[Tensor, int]]] = None, modify_inner_state: Optional[bool] = None) Dict[str, Union[Tensor, int]] ¶
update the environment
- Parameters:
action (torch.Tensor) – the action to perform. Defaults to None.
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
modify_inner_state (Optional[bool], optional) – whether to modify the inner state. Defaults to None.
- Returns:
the new state
- Return type:
Dict[str, Union[torch.Tensor, int]]
- select_random_action() Tensor ¶
select a random action
- Returns:
the random action
- Return type:
torch.Tensor
- get_momentum_action() Tensor ¶
get the momentum action
- Returns:
the momentum action
- Return type:
torch.Tensor
- get_reverse_momentum_action() Tensor ¶
get the reverse momentum action
- Returns:
the reverse momentum action
- Return type:
torch.Tensor
envs.BasicDiscreteRealDataEnv module¶
- class envs.BasicDiscreteRealDataEnv.BasicDiscreteRealDataEnv(args: Namespace, data: Data, device: Optional[device] = None)¶
Bases:
BasicRealDataEnv
- static add_args(parser: ArgumentParser) None ¶
add arguments to the parser
to add arguments to the parser, modify the method as follows:
@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument( ... )
- Parameters:
parser (argparse.ArgumentParser) – the parser to add arguments to
- __init__(args: Namespace, data: Data, device: Optional[device] = None) None ¶
initialize the environment
- Parameters:
args (argparse.Namespace) – arguments
data (Data) – the data
device (Optional[str], optional) – device to run the environment. Defaults to None, which means to use the GPU if available.
- to(device: str) None ¶
move the environment to the given device
- Parameters:
device (str) – the device to move to
- find_action_index(action: Tensor) int ¶
given an action, find the index of the action in all_actions
- Parameters:
action (torch.Tensor) – the trading decision of each asset
- Returns:
the index of the action in all_actions, -1 if not found
- Return type:
int
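The lookup can be sketched as a linear scan over the action set. The sketch uses plain lists; the real method compares tensors (e.g. with torch.equal), and the contents of all_actions below are illustrative:

```python
from typing import List


def find_action_index(action: List[int], all_actions: List[List[int]]) -> int:
    """Return the index of `action` in `all_actions`, or -1 if absent."""
    for i, candidate in enumerate(all_actions):
        if candidate == action:
            return i
    return -1


all_actions = [[-1, 0], [0, 0], [1, 0]]  # hypothetical per-asset decisions
idx = find_action_index([1, 0], all_actions)
```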
- act(action: Tensor, state: Optional[Dict[str, Union[Tensor, int]]] = None) Tuple[Dict[str, Optional[Union[Tensor, int]]], Tensor, bool] ¶
update the environment with the given action at the given time
- Parameters:
action (torch.Tensor) – the action to take
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
- Returns:
the new state, reward, and whether the episode is done
- Return type:
Tuple[Dict[str, Optional[Union[torch.Tensor, int]]], torch.Tensor, bool]
- update(action: Tensor = None, state: Optional[Dict[str, Union[Tensor, int]]] = None, modify_inner_state: Optional[bool] = None) Dict[str, Union[Tensor, int]] ¶
update the environment
- Parameters:
action (torch.Tensor) – the action to perform. Defaults to None.
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
modify_inner_state (Optional[bool], optional) – whether to modify the inner state. Defaults to None.
- Returns:
the new state
- Return type:
Dict[str, Union[torch.Tensor, int]]
- select_random_action() Tensor ¶
select a random action
- Returns:
the random action
- Return type:
torch.Tensor
- possible_actions(state: Dict[str, Tensor] = None) List[Tensor] ¶
get the possible actions
- Parameters:
state (Dict[str, torch.Tensor], optional) – the current state. Defaults to None.
- Returns:
the possible actions
- Return type:
List[torch.Tensor]
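One common way to build a discrete action set is the cross-product of per-asset decisions (sell / hold / buy); whether this environment uses exactly that construction is an assumption, so treat the sketch as illustrative:

```python
import itertools
from typing import List, Tuple


def all_actions(asset_num: int,
                decisions: Tuple[int, ...] = (-1, 0, 1)) -> List[Tuple[int, ...]]:
    """Every combination of per-asset trading decisions (hypothetical construction)."""
    return list(itertools.product(decisions, repeat=asset_num))


actions = all_actions(3)  # 3 assets, 3 choices each
```

In practice possible_actions would also filter out actions that fail the validity checks (cash or asset shortage) for the given state.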
- get_momentum_action() Tensor ¶
get the momentum action
- Returns:
the momentum action
- Return type:
torch.Tensor
- get_reverse_momentum_action() Tensor ¶
get the reverse momentum action
- Returns:
the reverse momentum action
- Return type:
torch.Tensor
envs.BasicRealDataEnv module¶
- class envs.BasicRealDataEnv.BasicRealDataEnv(args: Namespace, data: Data, device: Optional[device] = None)¶
Bases:
BaseEnv
- static add_args(parser: ArgumentParser) None ¶
add arguments to the parser
to add arguments to the parser, modify the method as follows:
@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument( ... )
- Parameters:
parser (argparse.ArgumentParser) – the parser to add arguments to
- __init__(args: Namespace, data: Data, device: Optional[device] = None) None ¶
initialize the environment
- Parameters:
args (argparse.Namespace) – arguments
data (Data) – the data
device (Optional[str], optional) – device to run the environment. Defaults to None, which means to use the GPU if available.
- to(device: str) None ¶
move the environment to the given device
- Parameters:
device (str) – the device to move to
- get_asset_num() int ¶
get the number of assets, excluding risk-free asset
- Returns:
the number of assets
- Return type:
int
- train_time_range() range ¶
the range of time indices
- Returns:
the range of time indices
- Return type:
range
- test_time_range() range ¶
the range of time indices
- Returns:
the range of time indices
- Return type:
range
- _get_price_change_ratio_tensor(time_index: Optional[int] = None) Tensor ¶
get the price change ratio tensor at a given time
- Parameters:
time_index (Optional[int], optional) – the time index to get the price change ratio. Defaults to None, which means to get the price change ratio at the current time.
- Returns:
the price change ratio tensor
- Return type:
torch.Tensor
- _get_price_tensor(time_index: Optional[int] = None) Tensor ¶
get the price tensor at a given time
- Parameters:
time_index (Optional[int], optional) – the time index to get the price. Defaults to None, which means to get the price at the current time.
- Returns:
the price tensor
- Return type:
torch.Tensor
- _get_price_tensor_in_window(time_index: int) Tensor ¶
get the price tensor in a window centered at a given time
- Parameters:
time_index (int) – the time index to get the price
- Returns:
the price tensor in the window
- Return type:
torch.Tensor
- _get_high_price_tensor_in_window(time_index: int) Tensor ¶
get the high price tensor in a window centered at a given time
- Parameters:
time_index (int) – the time index to get the high price
- Returns:
the high price tensor in the window
- Return type:
torch.Tensor
- _get_low_price_tensor_in_window(time_index: int) Tensor ¶
get the low price tensor in a window centered at a given time
- Parameters:
time_index (int) – the time index to get the low price
- Returns:
the low price tensor in the window
- Return type:
torch.Tensor
- _get_open_price_tensor_in_window(time_index: int) Tensor ¶
get the open price tensor in a window centered at a given time
- Parameters:
time_index (int) – the time index to get the open price
- Returns:
the open price tensor in the window
- Return type:
torch.Tensor
- _get_tensor_in_window(tensor: Tensor, time_index: int) Tensor ¶
get the tensor in a window centered at a given time
- Parameters:
tensor (torch.Tensor) – the tensor
time_index (int) – the time index to get the tensor
- Returns:
the tensor in the window
- Return type:
torch.Tensor
- state_dimension() Dict[str, Size] ¶
the dimension of the state tensors.
- Returns:
the dimension of the state tensors
- Return type:
Dict[str, torch.Size]
- state_tensor_names() List[str] ¶
the names of the state tensors
- Returns:
the names of the state tensors
- Return type:
List[str]
- get_state(state: Optional[Dict[str, Union[Tensor, int]]] = None) Optional[Dict[str, Union[Tensor, int]]] ¶
get the state tensors at the current time.
- Parameters:
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
- Returns:
the state tensors
- Return type:
Dict[str, Union[torch.Tensor, int]]
- reset() None ¶
reset the environment.
envs.ContinuousRealDataEnv1 module¶
- class envs.ContinuousRealDataEnv1.ContinuousRealDataEnv1(args: Namespace, data: Data, device: Optional[device] = None)¶
Bases:
BasicContinuousRealDataEnv
- static add_args(parser: ArgumentParser) None ¶
add arguments to the parser
to add arguments to the parser, modify the method as follows:
@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument( ... )
- Parameters:
parser (argparse.ArgumentParser) – the parser to add arguments to
- __init__(args: Namespace, data: Data, device: Optional[device] = None) None ¶
initialize the environment
- Parameters:
args (argparse.Namespace) – arguments
data (Data) – the data
device (Optional[str], optional) – device to run the environment. Defaults to None, which means to use the GPU if available.
- to(device: str) None ¶
move the environment to the given device
- Parameters:
device (str) – the device to move to
- state_dimension() Dict[str, Size] ¶
the dimension of the state tensors.
- Returns:
the dimension of the state tensors
- Return type:
Dict[str, torch.Size]
- state_tensor_names() List[str] ¶
the names of the state tensors
- Returns:
the names of the state tensors
- Return type:
List[str]
- get_state(state: Optional[Dict[str, Union[Tensor, int]]] = None) Optional[Dict[str, Union[Tensor, int]]] ¶
get the state tensors at the current time.
- Parameters:
state (Optional[Dict[str, torch.Tensor]], optional) – the state tensors. Defaults to None.
- Returns:
the state tensors
- Return type:
Dict[str, torch.Tensor]
- act(action_weight: Tensor, state: Optional[Dict[str, Union[Tensor, int]]] = None) Tuple[Dict[str, Optional[Union[Tensor, int]]], Tensor, bool] ¶
perform an action (the portfolio weight)
- Parameters:
action_weight (torch.Tensor) – the action to perform
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
- Returns:
the new state, reward, and whether the episode is done
- Return type:
Tuple[Dict[str, Optional[Union[torch.Tensor, int]]], torch.Tensor, bool]
- update(action_weight: Tensor = None, state: Optional[Dict[str, Union[Tensor, int]]] = None, modify_inner_state: Optional[bool] = None) Dict[str, Union[Tensor, int]] ¶
update the environment
- Parameters:
action_weight (torch.Tensor) – the action to perform, means the weight after trade. Defaults to None.
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
modify_inner_state (Optional[bool], optional) – whether to modify the inner state. Defaults to None.
- Returns:
the new state
- Return type:
Dict[str, Union[torch.Tensor, int]]
- select_random_action() Tensor ¶
select a random action
- Returns:
the random action
- Return type:
torch.Tensor
- get_momentum_action() Tensor ¶
get the momentum action
- Returns:
the momentum action
- Return type:
torch.Tensor
- get_reverse_momentum_action() Tensor ¶
get the reverse momentum action
- Returns:
the reverse momentum action
- Return type:
torch.Tensor
envs.DiscreteRealDataEnv1 module¶
- class envs.DiscreteRealDataEnv1.DiscreteRealDataEnv1(args: Namespace, data: Data, device: Optional[device] = None)¶
Bases:
BasicDiscreteRealDataEnv
- Reference:
original paper: https://arxiv.org/abs/1907.03665
- static add_args(parser: ArgumentParser) None ¶
add arguments to the parser
to add arguments to the parser, modify the method as follows:
@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument( ... )
- Parameters:
parser (argparse.ArgumentParser) – the parser to add arguments to
- __init__(args: Namespace, data: Data, device: Optional[device] = None) None ¶
initialize the environment
- Parameters:
args (argparse.Namespace) – arguments
data (Data) – the data
device (Optional[str], optional) – device to run the environment. Defaults to None, which means to use the GPU if available.
- to(device: str) None ¶
move the environment to the given device
- Parameters:
device (str) – the device to move to
- sample_distribution_and_set_episode() int ¶
sample a distribution and set the episode accordingly; please refer to the paper https://arxiv.org/abs/1907.03665 for more details
- Returns:
the episode index
- Return type:
int
- set_episode(episode: int) None ¶
set the episode given the episode index
- Parameters:
episode (int) – the episode index
- set_episode_for_testing() None ¶
special function to set the episode for testing
- train_time_range() range ¶
the range of time indices
- Returns:
the range of time indices
- Return type:
range
- test_time_range() range ¶
the range of time indices
- Returns:
the range of time indices
- Return type:
range
- pretrain_train_time_range(shuffle: bool = True) List ¶
the list of time indices for pretraining
- Parameters:
shuffle (bool, optional) – whether to shuffle. Defaults to True.
- Returns:
the list of time indices
- Return type:
List
- pretrain_eval_time_range() range ¶
the list of time indices for pretraining evaluation
- Returns:
the list of time indices
- Return type:
range
- state_dimension() Dict[str, Size] ¶
the dimension of the state tensors, including Xt_Matrix and Portfolio_Weight
- Returns:
the dimension of the state tensors
- Return type:
Dict[str, torch.Size]
- state_tensor_names() List[str] ¶
the names of the state tensors, including Xt_Matrix and Portfolio_Weight
- Returns:
the names of the state tensors
- Return type:
List[str]
- get_state(state: Optional[Dict[str, Union[Tensor, int]]] = None) Optional[Dict[str, Union[Tensor, int]]] ¶
get the state tensors at the current time, including Xt_Matrix and Portfolio_Weight
- Parameters:
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
- Returns:
the state tensors
- Return type:
Dict[str, torch.Tensor]
- _get_Xt_state(time_index: Optional[int] = None) Tensor ¶
get the Xt state tensor at a given time
- Parameters:
time_index (Optional[int], optional) – the time_index. Defaults to None, which means to get the Xt state tensor at the current time.
- Returns:
the Xt state tensor
- Return type:
torch.Tensor
- get_pretrain_input_and_target(time_index: int) Optional[Tuple[Dict[str, Tensor], Tensor]] ¶
get the Xt state tensor and the pretrain target at a given time
- Parameters:
time_index (int) – the time index
- Returns:
the input state and the pretrain target
- Return type:
Optional[Tuple[Dict[str, torch.Tensor], torch.Tensor]]
- act(action_idx: int, state: Optional[Dict[str, Union[Tensor, int]]] = None) Tuple[Dict[str, Optional[Union[Tensor, int]]], Tensor, bool] ¶
update the environment with the given action at the given time
- Parameters:
action_idx (int) – the id of the action to take
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
- Raises:
ValueError – action not valid
- Returns:
the new state, reward, and whether the episode is done
- Return type:
Tuple[Dict[str, Optional[Union[torch.Tensor, int]]], torch.Tensor, bool]
- update(action: int = None, state: Optional[Dict[str, Union[Tensor, int]]] = None, modify_inner_state: Optional[bool] = None) Dict[str, Union[Tensor, int]] ¶
update the environment with the given action index
- Parameters:
action (int) – the id of the action to take. Defaults to None.
state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.
modify_inner_state (Optional[bool], optional) – whether to modify the inner state. Defaults to None.
- Raises:
ValueError – the action is not valid
- Returns:
the new state
- Return type:
Dict[str, Union[torch.Tensor, int]]
- reset() None ¶
reset the environment to the initial state
- _cash_shortage(action: Tensor, portfolio_value: Optional[Tensor] = None, rf_weight: Optional[Tensor] = None) bool ¶
assert whether there is cash shortage after trading
- Parameters:
action (torch.Tensor) – the trading decision of each asset
portfolio_value (Optional[torch.Tensor], optional) – the portfolio value. Defaults to None.
rf_weight (Optional[torch.Tensor], optional) – the risk free weight. Defaults to None.
- Returns:
whether there is cash shortage after trading
- Return type:
bool
- _asset_shortage(action: Tensor, portfolio_weight: Optional[Tensor] = None, portfolio_value: Optional[Tensor] = None) bool ¶
assert whether there is asset shortage after trading
- Parameters:
action (torch.Tensor) – the trading decision of each asset
portfolio_weight (Optional[torch.Tensor], optional) – the portfolio weight. Defaults to None.
portfolio_value (Optional[torch.Tensor], optional) – the portfolio value. Defaults to None.
- Returns:
whether there is asset shortage after trading
- Return type:
bool
- _action_validity(action: Tensor, portfolio_weight: Optional[Tensor] = None, portfolio_value: Optional[Tensor] = None, rf_weight: Optional[Tensor] = None) bool ¶
assert whether the action is valid
- Parameters:
action (torch.Tensor) – the trading decision of each asset
portfolio_weight (Optional[torch.Tensor], optional) – the portfolio weight. Defaults to None.
portfolio_value (Optional[torch.Tensor], optional) – the portfolio value. Defaults to None.
rf_weight (Optional[torch.Tensor], optional) – the risk free weight. Defaults to None.
- Returns:
whether the action is valid
- Return type:
bool
- possible_actions() Tensor ¶
get all possible action indexes
- Returns:
all possible action indexes
- Return type:
torch.Tensor
- action_mapping(action_index: int, Q_Values: Tensor, state: Optional[Dict[str, Tensor]] = None) int ¶
perform action mapping based on the Q values
- Parameters:
action_index (int) – the index of the action to map
Q_Values (torch.Tensor) – the Q values of all actions
state (Optional[Dict[str, torch.Tensor]], optional) – the state tensors. Defaults to None.
- Raises:
ValueError – action not valid
- Returns:
the index of the mapped action
- Return type:
int
- _action_mapping_rule1(action: Tensor, Q_Values: Tensor, portfolio_value: Optional[Tensor] = None, rf_weight: Optional[Tensor] = None) int ¶
action mapping rule 1: if there is cash shortage, find the subset action with the highest Q value
- Parameters:
action (torch.Tensor) – the trading decision of each asset
Q_Values (torch.Tensor) – the Q values of all actions
portfolio_value (Optional[torch.Tensor], optional) – the portfolio value. Defaults to None.
rf_weight (Optional[torch.Tensor], optional) – the risk free weight. Defaults to None.
- Returns:
the index of the mapped action
- Return type:
int
- _action_mapping_rule2(action: Tensor, portfolio_weight: Tensor, portfolio_value: Tensor) int ¶
the action mapping rule 2: if there is asset shortage, don’t trade the asset with shortage
- Parameters:
action (torch.Tensor) – the trading decision of each asset
portfolio_weight (torch.Tensor) – the portfolio weight
portfolio_value (torch.Tensor) – the portfolio value
- Returns:
the index of the mapped action
- Return type:
int
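Rule 2 can be sketched in pure Python: any sell order that exceeds the amount actually held is replaced by "don't trade". The holdings representation (per-asset value = weight times portfolio value) is an assumption for illustration:

```python
from typing import List


def map_rule2(action: List[float], holdings: List[float]) -> List[float]:
    """Zero out any sell that exceeds the held amount (sketch of mapping rule 2).
    `action` is the trading decision per asset (sells negative);
    `holdings` is the hypothetical per-asset position value."""
    return [0.0 if (a < 0 and -a > h) else a for a, h in zip(action, holdings)]


# selling 5 while holding only 3 is suppressed; the other trades pass through
mapped = map_rule2([-5.0, 2.0, -1.0], holdings=[3.0, 10.0, 4.0])
```

The real method additionally maps the resulting trading decision back to its index in the discrete action set.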
- select_random_action() int ¶
select a random valid action, return its index
- Returns:
the index of the selected action
- Return type:
int
- get_momentum_action() int ¶
get the momentum action
- Returns:
the momentum action
- Return type:
int
- get_reverse_momentum_action() int ¶
get the reverse momentum action
- Returns:
the reverse momentum action
- Return type:
int
envs.DiscreteRealDataEnv2 module¶
- class envs.DiscreteRealDataEnv2.DiscreteRealDataEnv2(args: Namespace, data: Data, device: Optional[device] = None)¶
Bases:
BasicDiscreteRealDataEnv
- static add_args(parser: ArgumentParser) None ¶
add arguments to the parser
to add arguments to the parser, modify the method as follows:
@staticmethod def add_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( ... )
then add arguments to the parser
- Parameters:
parser (argparse.ArgumentParser) – the parser to add arguments to
- __init__(args: Namespace, data: Data, device: Optional[device] = None) None ¶
initialize the environment
- Parameters:
args (argparse.Namespace) – arguments
data (Data) – the data
device (Optional[str], optional) – device to run the environment. Defaults to None, which means to use the GPU if available.
- action_is_valid(action: Tensor, state: Optional[Dict[str, Tensor]] = None) bool ¶
check if the action is valid
- Parameters:
action (torch.Tensor) – the action
state (Optional[Dict[str, torch.Tensor]], optional) – the current state. Defaults to None.
- Returns:
whether the action is valid
- Return type:
bool
- possible_actions(state: Dict[str, Tensor] = None) List[Tensor] ¶
get the possible actions
- Parameters:
state (Dict[str, torch.Tensor], optional) – the current state. Defaults to None.
- Returns:
the possible actions
- Return type:
List[torch.Tensor]
- get_momentum_action() Tensor ¶
get the momentum action
- Returns:
the momentum action
- Return type:
torch.Tensor
- get_reverse_momentum_action() Tensor ¶
get the reverse momentum action
- Returns:
the reverse momentum action
- Return type:
torch.Tensor